• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 6843553218

12 Nov 2023 10:39PM UTC coverage: 90.178% (-0.03%) from 90.211%
6843553218

Pull #618

github

web-flow
Merge c884bbc44 into 2fdb1771d
Pull Request #618: Run all sync/async fixture/test code under the same contextvars.Context

136 of 151 new or added lines in 2 files covered. (90.07%)

2 existing lines in 1 file now uncovered.

4398 of 4877 relevant lines covered (90.18%)

8.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.5
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
10✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
21
from collections import OrderedDict, deque
10✔
22
from collections.abc import AsyncIterator, Iterable
10✔
23
from concurrent.futures import Future
10✔
24
from contextlib import suppress
10✔
25
from contextvars import Context, copy_context
10✔
26
from dataclasses import dataclass
10✔
27
from functools import partial, wraps
10✔
28
from inspect import (
10✔
29
    CORO_RUNNING,
30
    CORO_SUSPENDED,
31
    getcoroutinestate,
32
    iscoroutine,
33
)
34
from io import IOBase
10✔
35
from os import PathLike
10✔
36
from queue import Queue
10✔
37
from signal import Signals
10✔
38
from socket import AddressFamily, SocketKind
10✔
39
from threading import Thread
10✔
40
from types import TracebackType
10✔
41
from typing import (
10✔
42
    IO,
43
    Any,
44
    AsyncGenerator,
45
    Awaitable,
46
    Callable,
47
    Collection,
48
    ContextManager,
49
    Coroutine,
50
    Mapping,
51
    Optional,
52
    Sequence,
53
    Tuple,
54
    TypeVar,
55
    cast,
56
)
57
from weakref import WeakKeyDictionary
10✔
58

59
import sniffio
10✔
60

61
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
62
from .._core._eventloop import claim_worker_thread
10✔
63
from .._core._exceptions import (
10✔
64
    BrokenResourceError,
65
    BusyResourceError,
66
    ClosedResourceError,
67
    EndOfStream,
68
    WouldBlock,
69
)
70
from .._core._sockets import convert_ipv6_sockaddr
10✔
71
from .._core._streams import create_memory_object_stream
10✔
72
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
73
from .._core._synchronization import Event as BaseEvent
10✔
74
from .._core._synchronization import ResourceGuard
10✔
75
from .._core._tasks import CancelScope as BaseCancelScope
10✔
76
from ..abc import (
10✔
77
    AsyncBackend,
78
    IPSockAddrType,
79
    SocketListener,
80
    UDPPacketType,
81
    UNIXDatagramPacketType,
82
)
83
from ..lowlevel import RunVar
10✔
84
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
85

86
if sys.version_info >= (3, 11):
10✔
87
    from asyncio import Runner
5✔
88
else:
89
    import contextvars
6✔
90
    import enum
6✔
91
    import signal
6✔
92
    from asyncio import coroutines, events, exceptions, tasks
6✔
93

94
    from exceptiongroup import BaseExceptionGroup
6✔
95

96
    class _State(enum.Enum):
        """Lifecycle phases tracked by the backported ``Runner`` class."""

        # Initial state: the runner exists but has not set up an event loop yet
        CREATED = "created"
        # The event loop and the base context have been set up
        INITIALIZED = "initialized"
        # The runner was shut down; trying to initialize it again is an error
        CLOSED = "closed"
100

101
    class Runner:
        """
        Backport of :class:`asyncio.Runner` for Python versions older than 3.11.

        The runner lazily creates an event loop and a base
        :class:`contextvars.Context`, and lets the caller run any number of
        coroutines on that loop before shutting it down with :meth:`close`.
        """

        # Copied from CPython 3.11
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            """
            :param debug: if not ``None``, passed to ``loop.set_debug()`` when the
                loop is created
            :param loop_factory: optional callable producing the event loop; when
                omitted, a new loop is created and installed as the current event
                loop
            """
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            self._interrupt_count = 0
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop."""
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # Fall back to the local helper on loops that lack the method
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                # Only unset the current event loop if this runner installed it
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
            """
            Run a coroutine inside the embedded event loop.

            :param coro: the coroutine object to run
            :param context: the :class:`contextvars.Context` to create the task
                in; defaults to the runner's own base context
            :return: the return value of the coroutine
            :raises ValueError: if ``coro`` is not a coroutine object
            :raises RuntimeError: if called while an event loop is already running
                in this thread, or if the runner has been closed
            :raises KeyboardInterrupt: if SIGINT is received a second time while
                the coroutine is running

            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context

            # This backport only exists on Python < 3.11, where
            # loop.create_task() does not accept a "context" argument, so the
            # task has to be created from inside the target context instead.
            # NOTE: the task snapshots the context it is created in, so it works
            # on a copy of ``context`` rather than on ``context`` itself.
            task = context.run(self._loop.create_task, coro)

            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # Task.uncancel() may not exist on the Python versions this
                    # backport targets, hence the getattr() lookup
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler, but only if ours is still in place
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """
            Create the event loop and the base context on first use.

            :raises RuntimeError: if the runner has already been closed

            """
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            """
            SIGINT handler installed while :meth:`run` is executing.

            The first interrupt cancels the main task (if it is still running);
            any other invocation raises :exc:`KeyboardInterrupt` directly.
            """
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
228

229
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """
        Cancel all unfinished tasks of ``loop`` and run the loop until they exit.

        Any task that finished with an exception other than cancellation is
        reported through the loop's exception handler.

        :param loop: a non-running event loop whose remaining tasks to cancel

        """
        pending = tasks.all_tasks(loop)
        if pending:
            for pending_task in pending:
                pending_task.cancel()

            # Run the loop until each cancelled task has actually finished
            loop.run_until_complete(tasks.gather(*pending, return_exceptions=True))

            for finished_task in pending:
                if finished_task.cancelled():
                    continue

                error = finished_task.exception()
                if error is not None:
                    loop.call_exception_handler(
                        {
                            "message": "unhandled exception during asyncio.run() shutdown",
                            "exception": error,
                            "task": finished_task,
                        }
                    )
250

251
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """
        Shut down the default executor of ``loop`` without blocking the loop.

        The blocking ``shutdown(wait=True)`` call is made in a helper thread,
        which reports its outcome back to the loop through a future.

        :param loop: the event loop whose default executor to shut down

        """
        loop._executor_shutdown_called = True
        if loop._default_executor is None:
            return

        completion = loop.create_future()

        def shutdown_in_thread() -> None:
            # Runs outside the event loop thread, so the future may only be
            # touched via call_soon_threadsafe()
            try:
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
                loop.call_soon_threadsafe(completion.set_result, None)
            except Exception as exc:
                loop.call_soon_threadsafe(completion.set_exception, exc)

        helper = threading.Thread(target=shutdown_in_thread)
        helper.start()
        try:
            await completion
        finally:
            helper.join()
271

272

273
T_Retval = TypeVar("T_Retval")
10✔
274
T_contra = TypeVar("T_contra", contravariant=True)
10✔
275

276
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
10✔
277

278

279
def find_root_task() -> asyncio.Task:
    """
    Locate the root task of the current event loop.

    Three strategies are tried in order:

    #. the task cached in the ``_root_task`` run variable, if it is still running
    #. an unfinished task carrying the ``run_until_complete()`` done callback
       (or a callback from the ``uvloop.loop`` module); this one gets cached
    #. the host task of the outermost cancel scope of the current task

    If none of them yields a result, the current task is returned.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _context in candidate._callbacks:
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Look up the topmost task in the AnyIO task tree, if possible
    this_task = cast(asyncio.Task, current_task())
    task_state = _task_states.get(this_task)
    if task_state:
        scope = task_state.cancel_scope
        # Climb to the outermost cancel scope
        while scope and scope._parent_scope is not None:
            scope = scope._parent_scope

        if scope is not None:
            return cast(asyncio.Task, scope._host_task)

    return this_task
308

309

310
def get_callable_name(func: Callable) -> str:
    """
    Return a dotted name for ``func``, built from its module and qualified name.

    Either part is left out if the corresponding attribute is missing or falsy,
    so the result may be just one of them, or an empty string.
    """
    name_parts = []
    for attribute in ("__module__", "__qualname__"):
        part = getattr(func, attribute, None)
        if part:
            name_parts.append(part)

    return ".".join(name_parts)
314

315

316
#
317
# Event loop
318
#
319

320
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
10✔
321

322

323
def _task_started(task: asyncio.Task) -> bool:
    """
    Tell whether the coroutine of ``task`` is currently running or suspended.

    :return: ``True`` if the task has been started and has not finished
    :raises Exception: if the state of the task's coroutine cannot be inspected

    """
    try:
        coro_state = getcoroutinestate(task.get_coro())
    except AttributeError:
        # task coro is async_generator_asend https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not") from None

    return coro_state in (CORO_RUNNING, CORO_SUSPENDED)
330

331

332
#
333
# Timeouts and cancellation
334
#
335

336

337
class CancelScope(BaseCancelScope):
    """
    Cancel scope implementation for the asyncio backend.

    A scope keeps track of the tasks directly contained in it (``_tasks``), the
    task that entered it (``_host_task``) and the scope that was current in the
    host task when this one was entered (``_parent_scope``). Cancelling the scope
    repeatedly calls ``Task.cancel()`` on eligible tasks until none of them
    remain in the scope.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Instantiate directly via object.__new__(), skipping whatever
        # BaseCancelScope.__new__() does (presumably backend dispatch -- verify)
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        """
        :param deadline: the event loop clock value at which to cancel the scope
            (``math.inf`` means no timeout)
        :param shield: ``True`` to protect the contents of this scope from the
            cancellation of parent scopes
        """
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Handle of the scheduled deadline check, if any
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Handle of the scheduled cancellation retry, if any
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls made by this scope
        self._cancel_calls: int = 0
        # Host task's cancelling() count at scope entry (Python 3.11+ only)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        """
        Make this scope the current cancel scope of the running task.

        :raises RuntimeError: if the scope is already active

        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First cancel scope entered in this task: there is no parent scope
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest this scope inside the task's current scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        # Cancel right away or schedule the deadline check, as appropriate
        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            # Remember the baseline so _uncancel() knows when to stop
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Leave the scope and restore the parent scope as the task's current scope.

        :return: the result of :meth:`_uncancel` if the block was exited with a
            ``CancelledError`` while this scope was cancelled (a true value
            suppresses the exception), ``None`` otherwise
        :raises RuntimeError: if the scope is not active, is exited from a task
            other than its host task, or is not the host task's current scope

        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if isinstance(exc_val, CancelledError) and self._cancel_called:
            self._cancelled_caught = self._uncancel(exc_val)
            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Undo the cancellations made by this scope and determine whether
        ``cancelled_exc`` was caused by it.

        :param cancelled_exc: the exception the scope's block was exited with
        :return: ``True`` if the exception should be considered caught by this
            scope

        """
        if sys.version_info < (3, 9) or self._host_task is None:
            # Task.cancel() only accepts a message on 3.9+, so there is nothing to
            # tell cancellations apart by; attribute the exception to this scope
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                # Stop once the count is back at (or below) the entry baseline
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        self._cancel_calls = 0
        # Otherwise look for the message this scope passes to Task.cancel()
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        """
        Cancel the scope if its deadline has passed, else schedule another check
        of this method at the deadline. Does nothing if there is no deadline.
        """
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks that already have a cancellation pending
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                # Reached this scope without passing through a shielded one
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    waiter = task._fut_waiter  # type: ignore[attr-defined]
                    # Don't cancel if the future the task waits on is already done
                    if not isinstance(waiter, asyncio.Future) or not waiter.done():
                        self._cancel_calls += 1
                        if sys.version_info >= (3, 9):
                            task.cancel(f"Cancelled by cancel scope {id(self):x}")
                        else:
                            task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            # Only consider cancelled scopes with no delivery currently scheduled
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """
        Return ``True`` if any parent scope has been cancelled, looking no further
        than the first shielded parent (which is itself not checked).
        """
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """
        Cancel this scope.

        Only the first call has any effect. If the scope has not been entered yet,
        delivery of the cancellation is deferred until it is.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        """The event loop clock value at which this scope gets cancelled."""
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        # Drop the check scheduled for the previous deadline
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Apply the new deadline right away if the scope is in use
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        """``True`` if :meth:`cancel` has been called on this scope."""
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        """The value :meth:`__exit__` last returned for a ``CancelledError``."""
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        """``True`` if this scope is shielded from the cancellation of parents."""
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Removing the shield exposes the scope to cancelled parent scopes
            if not value:
                self._deliver_cancellation_to_parent()
562

563

564
#
565
# Task states
566
#
567

568

569
class TaskState:
    """
    Holds the extra per-task data that AnyIO needs to track.

    This data is kept outside the Task object because nothing is guaranteed about
    the implementation of the Task class.
    """

    __slots__ = ("parent_id", "cancel_scope")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        """
        :param parent_id: ``id()`` of the parent task, if any
        :param cancel_scope: the cancel scope the task is currently in, if any
        """
        self.parent_id, self.cancel_scope = parent_id, cancel_scope
580

581

582
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
10✔
583

584

585
#
586
# Task groups
587
#
588

589

590
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object handed to a task being started.

    Calling :meth:`started` resolves the associated future and records the given
    parent task id in the state of the calling task.
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        """
        :param future: the future to resolve with the start value
        :param parent_id: the parent id to assign to the task calling
            :meth:`started`
        """
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        """
        Signal that the task has started, passing ``value`` to the future.

        :raises RuntimeError: if the future has already been resolved

        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None
        else:
            # Switch the calling task over to its designated parent
            calling_task = cast(asyncio.Task, current_task())
            _task_states[calling_task].parent_id = self._parent_id
605

606

607
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
10✔
608
    exceptions = list(excgroup.exceptions)
×
609
    modified = False
×
610
    for i, exc in enumerate(exceptions):
×
611
        if isinstance(exc, BaseExceptionGroup):
×
612
            new_exc = collapse_exception_group(exc)
×
613
            if new_exc is not exc:
×
614
                modified = True
×
615
                exceptions[i] = new_exc
×
616

617
    if len(exceptions) == 1:
×
618
        return exceptions[0]
×
619
    elif modified:
×
620
        return excgroup.derive(exceptions)
×
621
    else:
622
        return excgroup
×
623

624

625
class TaskGroup(abc.TaskGroup):
    """
    Task group implementation for the asyncio backend.

    Child tasks are tracked in the ``_tasks`` set of the group's cancel scope, and
    exceptions other than ``CancelledError`` raised by the host or the child tasks
    are collected in ``_exceptions`` to be raised together when the group exits.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # True between __aenter__() and the end of __aexit__()
        self._active = False
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks to finish and report any collected exceptions.

        :return: the return value of the cancel scope's ``__exit__()``
        :raises BaseExceptionGroup: if the host task or any child task raised an
            exception other than ``CancelledError``
        :raises asyncio.CancelledError: if the host task was cancelled while
            waiting for the child tasks (see the comment near the end)

        """
        # Exiting the scope removes the host task from its task set, leaving only
        # the child tasks there
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The block raised: cancel the child tasks
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        # Child tasks remove themselves from the set as they finish (see the
        # task_done() callback in _spawn())
        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the  child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task that runs ``func(*args)`` within this task group.

        :param func: a callable returning a coroutine object
        :param args: positional arguments to call ``func`` with
        :param name: name of the task (derived from ``func`` if ``None``)
        :param task_status_future: if given, ``func`` receives a ``task_status``
            keyword argument whose ``started()`` method resolves this future
        :return: the newly created task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` did not return a coroutine object

        """

        def task_done(_task: asyncio.Task) -> None:
            # Done callback: stop tracking the task and record its outcome
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Dig out the innermost CancelledError from the context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    # Plain child task, or one that has already called started():
                    # record the failure and cancel the rest of the group
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    self.cancel_scope.cancel()
                else:
                    # Failed before calling started(): pass the error to the
                    # awaiter of the future instead
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # The task is initially parented to the calling task; the task status
            # object switches it over to the group's host task in started()
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Start a new child task without waiting for it to signal readiness."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """
        Start a new child task and wait until it calls ``task_status.started()``.

        Returns the result of the task status future, i.e. the value the child
        task passed to ``started()``. If this call is cancelled while waiting, the
        child task is cancelled and awaited before the ``CancelledError`` is
        re-raised.
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
765

766

767
#
768
# Threads
769
#
770

771
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
10✔
772

773

774
class WorkerThread(Thread):
    """
    Thread that executes queued callables and reports results to the event loop.

    Work items are ``(context, func, args, future)`` tuples placed on
    :attr:`queue`; ``None`` is the shutdown sentinel.

    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        """
        :param root_task: task whose ``_loop`` attribute supplies the event loop
        :param workers: set this worker removes itself from in :meth:`stop`
        :param idle_workers: deque this worker joins after finishing a job

        """
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Bounded to two entries; presumably one job plus the shutdown
        # sentinel -- confirm against the code that feeds this queue.
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Mark this worker idle again and resolve ``future`` with the outcome.

        Scheduled from :meth:`run` through ``loop.call_soon_threadsafe()``.

        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # Delivered wrapped in a RuntimeError, keeping the original as cause
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        """Process queued work items until the ``None`` sentinel arrives."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            # A None item is the shutdown command; task_done() is not called
            # for it.
            while (item := self.queue.get()) is not None:
                context, func, args, future = item
                if not future.cancelled():
                    outcome = None
                    failure: BaseException | None = None
                    try:
                        outcome = context.run(func, *args)
                    except BaseException as exc:
                        failure = exc

                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, outcome, failure
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Queue the shutdown sentinel and drop this worker from the pool records.

        :param f: ignored by this method

        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        # The worker is only in the idle deque when it is not running a job
        with suppress(ValueError):
            self.idle_workers.remove(self)
844

845

846
# Run variables for the worker thread pool: a deque of idle workers and the set
# of all workers. NOTE(review): the scoping rules come from RunVar, which is
# defined outside this part of the file.
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
850

851

852
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop running when it is created."""

    def __new__(cls) -> BlockingPortal:
        # Allocate directly via object, bypassing the base class' __new__
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """
        Ask the portal's event loop to start ``_call_func`` in the task group.

        :param func: callable forwarded to ``_call_func``
        :param args: positional arguments forwarded to ``_call_func``
        :param kwargs: keyword arguments forwarded to ``_call_func``
        :param name: name given to the spawned task
        :param future: future forwarded to ``_call_func``

        """
        spawner = partial(self._task_group.start_soon, name=name)
        spawn_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(spawner, spawn_args, self._loop)
873

874

875
#
876
# Subprocesses
877
#
878

879

880
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream backed by an :class:`asyncio.StreamReader`."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the wrapped reader.

        :raises EndOfStream: if the reader returned no data

        """
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        """Feed EOF to the wrapped reader."""
        self._stream.feed_eof()
893

894

895
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream backed by an :class:`asyncio.StreamWriter`."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write ``item`` to the wrapped writer and wait for it to drain."""
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        """Close the wrapped writer (without waiting for it to finish)."""
        self._stream.close()
905

906

907
@dataclass(eq=False)
class Process(abc.Process):
    """Wrapper around an asyncio subprocess and its optional stream wrappers."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close stdin, stdout and stderr (those that are set), then wait."""
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the wrapped process to exit and return its return code."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Delegate to the wrapped process' ``terminate()``."""
        self._process.terminate()

    def kill(self) -> None:
        """Delegate to the wrapped process' ``kill()``."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """Deliver ``signal`` to the wrapped process."""
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The wrapped process' ``pid``."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The wrapped process' ``returncode``."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The stdin wrapper, or ``None`` if there is none."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The stdout wrapper, or ``None`` if there is none."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The stderr wrapper, or ``None`` if there is none."""
        return self._stderr
955

956

957
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shuts down worker processes belonging to this event loop.

    :param workers: the worker processes to clean up
    :param _task: unused

    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        # Not every event loop policy implements child watchers
        with suppress(NotImplementedError):
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): workers whose returncode is still None are skipped, so
        # only processes that already report a return code are touched here --
        # confirm that this polarity is intended.
        if process.returncode is not None:
            process._stdin._stream._transport.close()  # type: ignore[union-attr]
            process._stdout._stream._transport.close()  # type: ignore[union-attr]
            process._stderr._stream._transport.close()  # type: ignore[union-attr]
            process.kill()
            if child_watcher:
                child_watcher.remove_child_handler(process.pid)
980

981

982
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps until cancelled. On cancellation it kills every worker that has no
    return code yet, closes all workers and then returns normally.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # First pass: kill everything still running
        for worker in workers:
            if worker.returncode is None:
                worker.kill()

        # Second pass: release the resources of every worker
        for worker in workers:
            await worker.aclose()
1000

1001

1002
#
1003
# Sockets and networking
1004
#
1005

1006

1007
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that buffers received chunks and exposes read/write events.

    ``read_event`` is set when data, EOF or connection loss is signalled;
    ``write_event`` is set while writing is allowed.

    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            # Stored as a BrokenResourceError with the original error as cause
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up any pending reader and writer
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        # Installs a fresh, unset event instead of clearing the current one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
1041

1042

1043
class DatagramProtocol(asyncio.DatagramProtocol):
    """Protocol that buffers received datagrams and exposes read/write events."""

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded buffer: once 100 packets are queued, appending another one
        # discards the oldest. The limit is an arbitrary value.
        self.read_queue = deque(maxlen=100)
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        # ``exc`` is not recorded; both events are set to wake pending waiters
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1072

1073

1074
class SocketStream(abc.SocketStream):
    """Byte stream built on an asyncio transport and a :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._closed = False
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next buffered chunk, truncated to at most ``max_bytes``.

        :raises ClosedResourceError: nothing is buffered and this stream was
            closed through :meth:`aclose`
        :raises EndOfStream: nothing is buffered and no error was recorded
        :raises Exception: the protocol's recorded exception, if any

        """
        protocol, transport = self._protocol, self._transport
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Only read from the transport while waiting for the protocol to
            # signal data, EOF or connection loss
            if not (protocol.read_event.is_set() or transport.is_closing()):
                transport.resume_reading()
                await protocol.read_event.wait()
                transport.pause_reading()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                if protocol.exception:
                    raise protocol.exception from None

                raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk and push the remainder back
                leftover = chunk[max_bytes:]
                chunk = chunk[:max_bytes]
                protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call
            # will block until data is available
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport, then wait until writing is allowed.

        :raises ClosedResourceError: this stream was closed via :meth:`aclose`
        :raises BrokenResourceError: the write failed on a closing transport
        :raises Exception: the protocol's recorded exception, if any

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError

            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise

                raise BrokenResourceError from exc

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Half-close the write side, ignoring ``OSError``."""
        with suppress(OSError):
            self._transport.write_eof()

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        with suppress(OSError):
            self._transport.write_eof()

        # Close, yield to the event loop once, then abort the transport
        self._transport.close()
        await sleep(0)
        self._transport.abort()
1156

1157

1158
class _RawSocketMixin:
    """
    Shared plumbing for socket wrappers that operate on a raw socket object.

    Provides the resource guards, helpers that wait for the socket to become
    readable/writable via the event loop, and :meth:`aclose`.

    """

    # Futures of the waits currently in progress, if any. Each instance
    # attribute is deleted again by the future's done callback, which makes the
    # lookup fall back to these class-level ``None`` defaults.
    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """
        Return a future that completes once ``loop`` reports the socket readable.

        The reader registration is removed when the future is done for any
        reason (readable, closed via :meth:`aclose`, or cancelled).

        """

        def callback(f: object) -> None:
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """
        Return a future that completes once ``loop`` reports the socket writable.

        The writer registration is removed when the future is done for any
        reason (writable, closed via :meth:`aclose`, or cancelled).

        """

        def callback(f: object) -> None:
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        """
        Close the raw socket and wake up any pending readable/writable wait.

        Only the first call has any effect.

        """
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # A wait future can already be done (socket became ready, or the
            # waiter was cancelled) while its done callback -- which drops the
            # attribute -- has not run yet. Completing it a second time would
            # raise InvalidStateError, so only pending futures are resolved.
            for pending in (self._receive_future, self._send_future):
                if pending is not None and not pending.done():
                    pending.set_result(None)
1202

1203

1204
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """Stream over a raw socket with support for passing file descriptors."""

    async def send_eof(self) -> None:
        """Shut down the write side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes, waiting until the socket is readable.

        :raises EndOfStream: ``recv()`` returned no data
        :raises ClosedResourceError: ``recv()`` failed after this stream was
            closed
        :raises BrokenResourceError: ``recv()`` failed for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                if not data:
                    raise EndOfStream

                return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, waiting for writability whenever needed.

        :raises ClosedResourceError: ``send()`` failed after this stream was
            closed
        :raises BrokenResourceError: ``send()`` failed for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                # Drop whatever the kernel accepted and retry with the rest
                remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message together with file descriptors sent alongside it.

        :param msglen: maximum number of message bytes to read
        :param maxfds: maximum number of file descriptors to accept
        :return: a tuple of (message, list of file descriptors)
        :raises ValueError: ``msglen`` is not a non-negative integer or
            ``maxfds`` is not a positive integer
        :raises EndOfStream: neither message data nor ancillary data arrived
        :raises RuntimeError: ancillary data other than ``SCM_RIGHTS`` arrived

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")

        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream

                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore a trailing partial item, if any
            usable = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
            fds.frombytes(cmsg_data[:usable])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send ``message`` along with the given file descriptors.

        :param message: non-empty payload to send
        :param fds: integers and/or ``IOBase`` objects (their ``fileno()`` is
            used); entries of any other type are skipped
        :raises ValueError: ``message`` or ``fds`` is empty

        """
        if not message:
            raise ValueError("message must not be empty")

        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    return
1318

1319

1320
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts connections on a raw socket via the event loop."""

    # Cancel scope of the accept() call in progress, if any
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`SocketStream`.

        :raises ClosedResourceError: the listener is closed, or gets closed
            while this call is waiting

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    connection, _peer = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # aclose() cancels the scope after setting _closed
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        connection.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, connection
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, interrupting an in-progress :meth:`accept`."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            with suppress(ValueError, NotImplementedError):
                self._loop.remove_reader(self._raw_socket)

            self._accept_scope.cancel()
            # Yield to the event loop once before closing the socket
            await sleep(0)

        self._raw_socket.close()
1378

1379

1380
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections directly on a raw socket object."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: accepting failed after the listener was
            closed
        :raises BrokenResourceError: accepting failed for any other reason

        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    connection, _ = self.__raw_socket.accept()
                    connection.setblocking(False)
                    return UNIXSocketStream(connection)
                except BlockingIOError:
                    # Nothing to accept yet: wait until the event loop reports
                    # the listening socket readable, then try again
                    readable: asyncio.Future = asyncio.Future()

                    def _unregister(_: object) -> None:
                        self._loop.remove_reader(self.__raw_socket)

                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(_unregister)
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener closed and close the raw socket."""
        self._closed = True
        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
1415

1416

1417
class UDPSocket(abc.UDPSocket):
    """Datagram socket on a datagram transport/:class:`DatagramProtocol` pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._closed = False
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next buffered ``(data, address)`` packet.

        :raises ClosedResourceError: nothing is buffered and this socket was
            closed through :meth:`aclose`
        :raises BrokenResourceError: nothing is buffered otherwise

        """
        protocol = self._protocol
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not protocol.read_queue and not self._transport.is_closing():
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send ``item``, which is unpacked into the arguments of ``sendto()``.

        :raises ClosedResourceError: this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: the transport is closing for another reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1463

1464

1465
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Datagram socket whose packets carry no explicit address for the caller."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._closed = False
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next buffered packet.

        :raises ClosedResourceError: nothing is buffered and this socket was
            closed through :meth:`aclose`
        :raises BrokenResourceError: nothing is buffered otherwise

        """
        protocol = self._protocol
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not protocol.read_queue and not self._transport.is_closing():
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                payload, *_ = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

            return payload

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` through the transport.

        :raises ClosedResourceError: this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: the transport is closing for another reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1513

1514

1515
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Datagram socket on a raw socket, using ``recvfrom()``/``sendto()``."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Receive one datagram of up to 65536 bytes via ``recvfrom()``.

        :return: whatever ``recvfrom()`` returned
        :raises ClosedResourceError: receiving failed after this socket was
            closed
        :raises BrokenResourceError: receiving failed for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    packet = self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                return packet

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send ``item``, which is unpacked into the arguments of ``sendto()``.

        :raises ClosedResourceError: sending failed after this socket was closed
        :raises BrokenResourceError: sending failed for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                return
1549

1550

1551
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket driven by a non-blocking raw socket."""

    async def receive(self) -> bytes:
        """
        Receive a single datagram from the connected peer.

        Retries ``recv()`` until it stops raising :exc:`BlockingIOError`, waiting
        for the socket to become readable between attempts.

        :return: the bytes returned by ``recv(65536)``
        :raises ClosedResourceError: if an :exc:`OSError` occurs while
            ``self._closing`` is set
        :raises BrokenResourceError: if any other :exc:`OSError` occurs

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Nothing to read yet; wait and try again
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: bytes) -> None:
        """
        Send a single datagram to the connected peer.

        Retries ``send()`` until it stops raising :exc:`BlockingIOError`, waiting
        for the socket to become writable between attempts.

        :param item: the payload to send
        :raises ClosedResourceError: if an :exc:`OSError` occurs while
            ``self._closing`` is set
        :raises BrokenResourceError: if any other :exc:`OSError` occurs

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # Send buffer is full; wait and try again
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
7✔
1585

1586

1587
# Registries of the asyncio.Event objects that wait_socket_readable() and
# wait_socket_writable() are currently waiting on, keyed by socket.
# NOTE(review): RunVar presumably scopes these to the running event loop — confirm.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
10✔
1589

1590

1591
#
1592
# Synchronization
1593
#
1594

1595

1596
class Event(BaseEvent):
    """Event implementation that delegates to an :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Instantiate directly through object.__new__ instead of BaseEvent.__new__
        # NOTE(review): presumably the base class dispatches to the backend's
        # implementation, which would recurse here — confirm
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the flag of the wrapped event."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the wrapped event's flag has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """
        Wait until the flag has been set.

        If the flag is already set, a checkpoint is performed instead so that
        this call always yields control to the event loop.

        """
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics containing the number of tasks waiting on this event."""
        # asyncio.Event has no public API for this, hence the private attribute
        waiter_count = len(self._event._waiters)  # type: ignore[attr-defined]
        return EventStatistics(waiter_count)
9✔
1617

1618

1619
class CapacityLimiter(BaseCapacityLimiter):
    """
    Capacity limiter that tracks borrowers in a set and queues waiters in FIFO
    order, waking them through :class:`asyncio.Event` objects.

    """

    # Class-level default so the ``total_tokens`` setter can read the previous
    # value the first time it is invoked from ``__init__``
    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Instantiate directly through object.__new__ rather than the base class
        # NOTE(review): presumably BaseCapacityLimiter.__new__ dispatches to the
        # backend's implementation — confirm
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        # Maps each waiting borrower to the event that wakes it up; insertion
        # order is the order in which waiters are served
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the total number of tokens.

        If the total is raised, one waiter is woken up for each added token.

        :raises TypeError: if the value is neither an ``int`` nor ``math.inf``
        :raises ValueError: if the value is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Remove each notified waiter from the queue, just like
        # release_on_behalf_of() does. Leaving woken waiters queued would make
        # acquire_on_behalf_of_nowait() keep raising WouldBlock and would cause a
        # later release to waste its wake-up on an already notified waiter.
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens that are currently borrowed."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens that are currently not borrowed."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """
        Acquire a token for the current task without waiting.

        :raises WouldBlock: if a token is not immediately available

        """
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower`` without waiting.

        :param borrower: the entity borrowing the token
        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if other borrowers are already waiting or no tokens
            are available

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        # A non-empty wait queue also blocks, so queued waiters are not bypassed
        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower``, waiting if necessary.

        :param borrower: the entity borrowing the token
        :raises RuntimeError: if ``borrower`` already holds a token

        """
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Leave the queue if the wait was interrupted
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # The token was taken for ``borrower``, which is not necessarily
                # the current task, so it must be returned on its behalf
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower``.

        :param borrower: the entity holding the token
        :raises RuntimeError: if ``borrower`` does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """
        Return statistics built from the borrowed token count, the total token
        count, the current borrowers and the number of queued waiters.

        """
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1736

1737

1738
# Holds the limiter that AsyncIOBackend.current_default_thread_limiter() creates
# on first use and returns on later calls
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
10✔
1739

1740

1741
#
1742
# Operating system signals
1743
#
1744

1745

1746
class _SignalReceiver:
    """
    Context manager and async iterator that yields the signals delivered to the
    event loop while the context is active.

    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        # Signals that have been delivered but not yet consumed by the iterator
        self._signal_queue: deque[Signals] = deque()
        # Resolved by _deliver() to wake up a pending __anext__() call
        self._future: asyncio.Future = asyncio.Future()
        # Signals for which a handler was actually installed in __enter__()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        """Queue a received signal and wake up the iterator if it is waiting."""
        self._signal_queue.append(signum)
        waiter = self._future
        if not waiter.done():
            waiter.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Deduplicate so that each signal gets exactly one handler
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        pending = self._signal_queue
        if not pending:
            # Nothing queued: arm a fresh future and wait for _deliver()
            self._future = asyncio.Future()
            await self._future

        return pending.popleft()
8✔
1786

1787

1788
#
1789
# Testing and debugging
1790
#
1791

1792

1793
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` for the given task.

    The parent id comes from the task's entry in ``_task_states`` and is ``None``
    when the task has no such entry.

    """
    state = _task_states.get(task)
    parent_id = None if state is None else state.parent_id
    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
10✔
1801

1802

1803
class TestRunner(abc.TestRunner):
    """
    Test runner that executes async fixtures and tests in one long-lived task.

    Coroutines are handed to that task through a memory object stream and their
    outcomes are passed back through futures.

    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        """
        :param debug: passed to the underlying ``Runner``
        :param use_uvloop: use uvloop's event loop factory, unless an explicit
            ``loop_factory`` was given
        :param loop_factory: callable that creates the event loop

        """
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        # Exceptions collected from the loop's exception handler and from tests
        self._exceptions: list[BaseException] = []
        # Created lazily by _call_in_runner_task()
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop of the underlying runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """
        Record :exc:`Exception` instances reported to the event loop; delegate
        everything else to the loop's default exception handler.

        """
        if not isinstance(context.get("exception"), Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(context["exception"])

    def _raise_async_exceptions(self) -> None:
        """Raise, and forget, any exceptions collected so far."""
        # Re-raise any exceptions raised in asynchronous callbacks
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """
        Await each received awaitable and report its outcome via its future.

        :param receive_stream: stream of ``(awaitable, future)`` pairs

        """
        with receive_stream:
            async for awaitable, outcome in receive_stream:
                try:
                    result = await awaitable
                except BaseException as exc:
                    # A cancelled future means nobody is waiting for the outcome
                    if not outcome.cancelled():
                        outcome.set_exception(exc)
                else:
                    if not outcome.cancelled():
                        outcome.set_result(result)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """
        Call ``func`` and have the runner task await the resulting awaitable.

        The runner task and its stream are created on first use.

        :param func: callable returning an awaitable
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :return: the result of the awaitable

        """
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            runner_coro = self._run_tests_and_fixtures(receive_stream)
            self._runner_task = self.get_loop().create_task(runner_coro)

        awaitable = func(*args, **kwargs)
        outcome: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((awaitable, outcome))
        return await outcome

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Run an async generator fixture: yield its value, then run its teardown
        when this generator is resumed.

        :param fixture_func: the async generator fixture function
        :param kwargs: keyword arguments for the fixture function
        :raises RuntimeError: if the async generator yields a second time

        """
        loop = self.get_loop()
        asyncgen = fixture_func(**kwargs)
        fixturevalue: T_Retval = loop.run_until_complete(
            self._call_in_runner_task(asyncgen.asend, None)
        )
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            # The generator finished as expected
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """
        Run a coroutine fixture in the runner task and return its value.

        :param fixture_func: the coroutine fixture function
        :param kwargs: keyword arguments for the fixture function

        """
        pending_call = self._call_in_runner_task(fixture_func, **kwargs)
        retval = self.get_loop().run_until_complete(pending_call)
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """
        Run a test function in the runner task.

        An :exc:`Exception` raised by the test is raised together with any
        exceptions collected from asynchronous callbacks.

        :param test_func: the coroutine test function
        :param kwargs: keyword arguments for the test function

        """
        loop = self.get_loop()
        try:
            loop.run_until_complete(self._call_in_runner_task(test_func, **kwargs))
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
10✔
1935

1936

1937
class AsyncIOBackend(AsyncBackend):
    """Backend implementation built on top of :mod:`asyncio`."""

    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func`` to completion in a new event loop.

        :param func: the coroutine function to run
        :param args: positional arguments for ``func``
        :param kwargs: accepted as part of the signature; not forwarded to
            ``func`` by this implementation
        :param options: backend options; the ``debug``, ``loop_factory`` and
            ``use_uvloop`` keys are honored
        :return: the return value of ``func``

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            # Register the root task for the duration of the run
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            # Imported lazily so uvloop is only needed when it was asked for
            import uvloop

            loop_factory = uvloop.new_event_loop

        # asyncio.run() cannot take a loop factory, so use a Runner instead
        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, which this backend uses as its token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the current time according to the running event loop's clock."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return :class:`asyncio.CancelledError`."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield control to the event loop only if a cancel scope in effect for the
        current task has been cancelled.

        Returns immediately if there is no current task or it has no task state.

        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        # Walk outwards through the enclosing scopes, stopping at a shielded one
        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once, inside a shielded cancel scope."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """
        Sleep for the given duration.

        :param delay: the number of seconds to sleep

        """
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """
        Create a cancel scope.

        :param deadline: the deadline passed to the new scope
        :param shield: the shield flag passed to the new scope

        """
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline among the cancel scopes in effect.

        :return: the smallest deadline found, ``-math.inf`` if an applicable
            scope has been cancelled, or ``math.inf`` if the current task has no
            task state

        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        # Scopes outside a shielded scope do not affect the current task
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """
        Create a new capacity limiter.

        :param total_tokens: the total number of tokens of the limiter

        """
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Call ``func`` in a worker thread and wait for its result.

        :param func: the callable to run
        :param args: positional arguments for ``func``
        :param cancellable: if ``False``, the wait for the result is shielded
            from cancellation
        :param limiter: the capacity limiter to use; the default thread limiter
            is used when omitted
        :return: the return value of ``func``

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not cancellable):
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # Run the function in a copy of the current context in which
                # sniffio's current async library variable is set to None
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, future))
                return await future

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """
        Run a coroutine function in the event loop from another thread and block
        until it completes.

        :param func: the coroutine function to call
        :param args: positional arguments for ``func``
        :param token: the event loop to run the coroutine in
        :return: the return value of the coroutine

        """
        loop = cast(AbstractEventLoop, token)
        # Schedule from a copy of this thread's context in which sniffio's
        # current async library variable is set to "asyncio"
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, func(*args), loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """
        Call a function in the event loop thread from another thread and block
        until it returns.

        :param func: the callable to run
        :param args: positional arguments for ``func``
        :param token: the event loop to run the callable in
        :return: the return value of ``func``

        """

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Only plain Exceptions are contained; anything else also
                # propagates out of the callback
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal."""
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """
        Start a subprocess.

        :param command: a single command line if ``shell`` is true, otherwise the
            sequence of program arguments
        :param shell: run the command through the shell if true
        :param stdin: passed through as the subprocess' standard input
        :param stdout: passed through as the subprocess' standard output
        :param stderr: passed through as the subprocess' standard error
        :param cwd: passed through as the subprocess' working directory
        :param env: passed through as the subprocess' environment
        :param start_new_session: passed through to the subprocess creation call
        :return: a process object wrapping the asyncio subprocess and its streams

        """
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        # Only wrap the standard streams that were actually opened
        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the given worker processes to be shut down.

        Starts a shutdown task for the workers and attaches a callback to the
        root task that forcibly shuts them down when the root task is done.

        :param workers: the set of worker processes

        """
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Connect to a TCP endpoint.

        :param host: the host to connect to
        :param port: the port to connect to
        :param local_address: the local address to bind to, if any
        :return: a socket stream for the new connection

        """
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """
        Connect to a UNIX socket.

        The raw socket is closed if connecting fails for any reason.

        :param path: the path to connect to
        :return: a socket stream for the new connection

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                # Wait until the socket becomes writable, then retry
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            except BaseException:
                raw_socket.close()
                raise
            else:
                return UNIXSocketStream(raw_socket)

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """
        Wrap a socket in a TCP listener.

        :param sock: the socket to wrap

        """
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """
        Wrap a socket in a UNIX socket listener.

        :param sock: the socket to wrap

        """
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP socket.

        :param family: the address family
        :param local_address: the local address to bind to, if any
        :param remote_address: the remote address to connect to, if any
        :param reuse_port: passed through to the datagram endpoint creation call
        :return: a connected UDP socket if ``remote_address`` was given, otherwise
            an unconnected one

        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        # Do not hand out a socket whose protocol has already recorded an error
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap a raw socket as a UNIX datagram socket, connecting it first if a
        remote path was given.

        The raw socket is closed if connecting fails for any reason.

        :param raw_socket: the socket to wrap
        :param remote_path: the path to connect to, if any
        :return: a connected UNIX datagram socket if ``remote_path`` was given,
            otherwise an unconnected one

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    # Wait until the socket becomes writable, then retry
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                except BaseException:
                    raw_socket.close()
                    raise
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Look up address information via the running event loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Look up name information via the running event loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until the given socket can be read from.

        :param sock: the socket to wait on
        :raises BusyResourceError: if a wait for readability is already
            registered for this socket
        :raises ClosedResourceError: if the socket's entry was removed from the
            registry while waiting

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # A missing entry means it was removed elsewhere while waiting
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until the given socket can be written to.

        :param sock: the socket to wait on
        :raises BusyResourceError: if a wait for writability is already
            registered for this socket
        :raises ClosedResourceError: if the socket's entry was removed from the
            registry while waiting

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register with the same object that remove_writer() receives below,
        # mirroring wait_socket_readable()
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            # A missing entry means it was removed elsewhere while waiting
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the default worker thread limiter, creating one with 40 tokens on
        first use.

        """
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """
        Create a receiver for the given signals.

        :param signals: the signals to receive

        """
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return information about the current task."""
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """Return information about every task that is not done yet."""
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Wait until every task other than the current one is waiting on something."""
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                # No pending waiter means the task is not blocked yet; sleep and
                # rescan all tasks from the start
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """
        Create a test runner.

        :param options: keyword arguments for :class:`TestRunner`

        """
        return TestRunner(**options)
10✔
2402

2403

2404
# Module-level alias for this module's backend class
# NOTE(review): presumably looked up by name when the backend is loaded — confirm
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc