• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 7187497562

12 Dec 2023 09:29PM UTC coverage: 90.883% (+0.04%) from 90.846%
7187497562

Pull #651

github

web-flow
Merge 7c5bf4f02 into f1d077042
Pull Request #651: Enabled Event and CapacityLimiter to be instantiated outside an event loop

64 of 66 new or added lines in 1 file covered. (96.97%)

2 existing lines in 1 file now uncovered.

4396 of 4837 relevant lines covered (90.88%)

8.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.54
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
20
from collections import OrderedDict, deque
10✔
21
from collections.abc import AsyncIterator, Generator, Iterable
10✔
22
from concurrent.futures import Future
10✔
23
from contextlib import suppress
10✔
24
from contextvars import Context, copy_context
10✔
25
from dataclasses import dataclass
10✔
26
from functools import partial, wraps
10✔
27
from inspect import (
10✔
28
    CORO_RUNNING,
29
    CORO_SUSPENDED,
30
    getcoroutinestate,
31
    iscoroutine,
32
)
33
from io import IOBase
10✔
34
from os import PathLike
10✔
35
from queue import Queue
10✔
36
from signal import Signals
10✔
37
from socket import AddressFamily, SocketKind
10✔
38
from threading import Thread
10✔
39
from types import TracebackType
10✔
40
from typing import (
10✔
41
    IO,
42
    Any,
43
    AsyncGenerator,
44
    Awaitable,
45
    Callable,
46
    Collection,
47
    ContextManager,
48
    Coroutine,
49
    Mapping,
50
    Optional,
51
    Sequence,
52
    Tuple,
53
    TypeVar,
54
    cast,
55
)
56
from weakref import WeakKeyDictionary
10✔
57

58
import sniffio
10✔
59

60
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
61
from .._core._eventloop import claim_worker_thread, threadlocals
10✔
62
from .._core._exceptions import (
10✔
63
    BrokenResourceError,
64
    BusyResourceError,
65
    ClosedResourceError,
66
    EndOfStream,
67
    WouldBlock,
68
)
69
from .._core._sockets import convert_ipv6_sockaddr
10✔
70
from .._core._streams import create_memory_object_stream
10✔
71
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
72
from .._core._synchronization import Event as BaseEvent
10✔
73
from .._core._synchronization import ResourceGuard
10✔
74
from .._core._tasks import CancelScope as BaseCancelScope
10✔
75
from ..abc import (
10✔
76
    AsyncBackend,
77
    IPSockAddrType,
78
    SocketListener,
79
    UDPPacketType,
80
    UNIXDatagramPacketType,
81
)
82
from ..lowlevel import RunVar
10✔
83
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
84

85
if sys.version_info >= (3, 11):
10✔
86
    from asyncio import Runner
5✔
87
else:
88
    import contextvars
6✔
89
    import enum
6✔
90
    import signal
6✔
91
    from asyncio import coroutines, events, exceptions, tasks
6✔
92

93
    from exceptiongroup import BaseExceptionGroup
6✔
94

95
    class _State(enum.Enum):
        """Lifecycle states of the backported ``Runner`` defined below."""

        # Initial state assigned in Runner.__init__(); no event loop exists yet
        CREATED = "created"
        # Set by Runner._lazy_init() once the loop and context have been created
        INITIALIZED = "initialized"
        # Set by Runner.close(); _lazy_init() refuses to run from this state
        CLOSED = "closed"
6✔
99

100
    class Runner:
        """
        Backport of ``asyncio.Runner``, used when the interpreter is older than 3.11.

        The event loop is created lazily on first use (``__enter__()``, ``get_loop()``
        or ``run()``) and torn down in ``close()``, which is also called when the
        ``with`` block is exited.
        """

        # Copied from CPython 3.11
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            """
            :param debug: if not ``None``, passed to ``loop.set_debug()`` when the loop
                is created
            :param loop_factory: callable used to create the event loop; if ``None``,
                ``events.new_event_loop()`` is used and the new loop is also installed
                as the current event loop
            """
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            # Context that run() executes tasks in unless an explicit one is passed
            self._context = None
            # Number of SIGINTs received during the current run() call
            self._interrupt_count = 0
            # True once this runner has called events.set_event_loop()
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            """Initialize the runner (if necessary) and return it."""
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException],
            exc_val: BaseException,
            exc_tb: TracebackType,
        ) -> None:
            """Close the runner; exceptions from the ``with`` body are not suppressed."""
            self.close()

        def close(self) -> None:
            """
            Shutdown and close event loop.

            Does nothing unless the runner is in the ``INITIALIZED`` state.
            """
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # Loops lacking shutdown_default_executor() get the module-level
                # fallback implementation instead
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                # Always release the loop, even if one of the shutdown steps failed
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop, creating it first if necessary."""
            self._lazy_init()
            return self._loop

        def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
            """
            Run a coroutine inside the embedded event loop.

            :param coro: the coroutine object to run
            :param context: context to create the task in; defaults to the context
                copied when the runner was initialized
            :return: the return value of the coroutine
            :raises ValueError: if ``coro`` is not a coroutine object
            :raises RuntimeError: if called while an event loop is running in this
                thread, or if the runner has been closed
            :raises KeyboardInterrupt: if the task was cancelled after a SIGINT and
                ``task.uncancel()`` (where available) reports no remaining
                cancellation requests

            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            task = context.run(self._loop.create_task, coro)

            # Only take over SIGINT when running in the main thread and nobody else
            # has replaced the default handler
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # Task.uncancel() does not exist on every supported version,
                    # hence the getattr()
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler, but only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """
            Create the event loop and the default context on first use.

            :raises RuntimeError: if the runner has already been closed

            """
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            """
            SIGINT handler installed by ``run()``.

            The first interrupt cancels the main task (if it is still running); any
            other invocation raises ``KeyboardInterrupt`` directly.
            """
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
×
227

228
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
6✔
229
        to_cancel = tasks.all_tasks(loop)
6✔
230
        if not to_cancel:
6✔
231
            return
6✔
232

233
        for task in to_cancel:
6✔
234
            task.cancel()
6✔
235

236
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
6✔
237

238
        for task in to_cancel:
6✔
239
            if task.cancelled():
6✔
240
                continue
6✔
241
            if task.exception() is not None:
6✔
242
                loop.call_exception_handler(
2✔
243
                    {
244
                        "message": "unhandled exception during asyncio.run() shutdown",
245
                        "exception": task.exception(),
246
                        "task": task,
247
                    }
248
                )
249

250
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
6✔
251
        """Schedule the shutdown of the default executor."""
252

253
        def _do_shutdown(future: asyncio.futures.Future) -> None:
3✔
254
            try:
3✔
255
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
3✔
256
                loop.call_soon_threadsafe(future.set_result, None)
3✔
257
            except Exception as ex:
×
258
                loop.call_soon_threadsafe(future.set_exception, ex)
×
259

260
        loop._executor_shutdown_called = True
3✔
261
        if loop._default_executor is None:
3✔
262
            return
3✔
263
        future = loop.create_future()
3✔
264
        thread = threading.Thread(target=_do_shutdown, args=(future,))
3✔
265
        thread.start()
3✔
266
        try:
3✔
267
            await future
3✔
268
        finally:
269
            thread.join()
3✔
270

271

272
# Generic return type of callables and coroutines passed to this backend
T_Retval = TypeVar("T_Retval")
# Contravariant type of the value passed to TaskStatus.started()
T_contra = TypeVar("T_contra", contravariant=True)

# Cache used by find_root_task() to remember the root task it discovered
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
10✔
276

277

278
def find_root_task() -> asyncio.Task:
    """
    Locate the root task of the current event loop.

    Three strategies are tried in order: the cached value in ``_root_task``, a task
    that was started via ``run_until_complete()`` (recognized by its done callback),
    and finally the host task of the outermost cancel scope of the current task.
    If none of them yields a result, the current task itself is returned.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        if any(
            callback is _run_until_complete_cb
            or getattr(callback, "__module__", None) == "uvloop.loop"
            for callback, _context in candidate._callbacks
        ):
            _root_task.set(candidate)
            return candidate

    # Look up the topmost task in the AnyIO task tree, if possible
    this_task = cast(asyncio.Task, current_task())
    task_state = _task_states.get(this_task)
    if not task_state:
        return this_task

    scope = task_state.cancel_scope
    while scope and scope._parent_scope is not None:
        scope = scope._parent_scope

    if scope is None:
        return this_task

    return cast(asyncio.Task, scope._host_task)
×
307

308

309
def get_callable_name(func: Callable) -> str:
10✔
310
    module = getattr(func, "__module__", None)
10✔
311
    qualname = getattr(func, "__qualname__", None)
10✔
312
    return ".".join([x for x in (module, qualname) if x])
10✔
313

314

315
#
316
# Event loop
317
#
318

319
# Weakly keyed by event loop, so an entry does not keep its loop alive.
# NOTE(review): not referenced in the visible part of this module; presumably the
# per-event-loop storage behind RunVar. Confirm against its users.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
10✔
320

321

322
def _task_started(task: asyncio.Task) -> bool:
10✔
323
    """Return ``True`` if the task has been started and has not finished."""
324
    try:
10✔
325
        return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
10✔
326
    except AttributeError:
×
327
        # task coro is async_genenerator_asend https://bugs.python.org/issue37771
328
        raise Exception(f"Cannot determine if task {task} has started or not") from None
×
329

330

331
#
332
# Timeouts and cancellation
333
#
334

335

336
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of a cancel scope.

    A scope tracks the tasks directly contained in it (``_tasks``), the task it was
    entered in (``_host_task``) and its enclosing scope (``_parent_scope``). When
    cancelled, it repeatedly calls ``Task.cancel()`` on its eligible tasks until none
    are left or a shielded scope stands in the way.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Bypass the base class's __new__() and create the instance directly
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        """
        :param deadline: event loop time at which the scope cancels itself
            (``math.inf`` means no deadline)
        :param shield: whether the scope blocks cancellation coming from enclosing
            scopes
        """
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._cancel_called = False
        self._cancelled_caught = False
        # True between __enter__() and __exit__()
        self._active = False
        # Handle of the scheduled deadline callback, if any
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Handle of the scheduled _deliver_cancellation() retry, if any
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls issued by this scope
        self._cancel_calls: int = 0
        # Host task's cancelling() count at entry (only set on Python 3.11+)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        """
        Enter the scope in the current task.

        :raises RuntimeError: if the scope is already active

        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First cancel scope entered in this task: it has no parent scope
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Push this scope on top of the task's current scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Leave the scope and restore the parent scope as the task's current scope.

        :return: ``True`` if the exception being propagated contains a
            ``CancelledError`` attributed to this scope (which is then swallowed),
            ``False`` if the scope was cancelled but no such error was found,
            ``None`` if the scope was not cancelled or no exception was raised
        :raises RuntimeError: if the scope is not active, is exited from a different
            task, or is not the task's current (innermost) cancel scope

        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if self._cancel_called and exc_val is not None:
            # The CancelledError may be wrapped in (nested) exception groups
            for exc in iterate_exceptions(exc_val):
                if isinstance(exc, CancelledError):
                    self._cancelled_caught = self._uncancel(exc)
                    if self._cancelled_caught:
                        break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Decide whether ``cancelled_exc`` was caused by this scope, undoing this
        scope's cancellation requests on the host task where that is supported.

        :return: ``True`` if the cancellation is attributed to this scope

        """
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                # Stop once the count is back to (or below) its value at entry
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        self._cancel_calls = 0
        # Fall back to matching the message passed to Task.cancel() by
        # _deliver_cancellation()
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed, else schedule a re-check."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks that already have a cancellation pending
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                # Reached this scope without hitting a shield
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    waiter = task._fut_waiter  # type: ignore[attr-defined]
                    # Don't cancel a task whose awaited future has already completed
                    if not isinstance(waiter, asyncio.Future) or not waiter.done():
                        self._cancel_calls += 1
                        if sys.version_info >= (3, 9):
                            # The message lets _uncancel() recognize this scope
                            task.cancel(f"Cancelled by cancel scope {id(self):x}")
                        else:
                            task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            # Only consider scopes that don't already have a delivery scheduled
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """
        Return ``True`` if an enclosing scope has been cancelled, looking no further
        than the first shielded scope.
        """
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """
        Cancel the scope; calling this again has no further effect.

        If the scope has not been entered yet, delivery is deferred until
        ``__enter__()``.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        """The event loop time at which this scope cancels itself."""
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timer only for a scope that is entered and not yet cancelled
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        """``True`` if ``cancel()`` has been called on this scope."""
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        """``True`` if ``__exit__()`` attributed a ``CancelledError`` to this scope."""
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        """Whether this scope blocks cancellation coming from enclosing scopes."""
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Dropping the shield exposes the tasks to any cancelled parent scope
            if not value:
                self._deliver_cancellation_to_parent()
10✔
566

567

568
#
569
# Task states
570
#
571

572

573
class TaskState:
    """
    Holds per-task bookkeeping (the ID of the parent task and the task's current
    cancel scope) outside of the Task object, since nothing is guaranteed about the
    Task implementation.
    """

    __slots__ = ("parent_id", "cancel_scope")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        self.parent_id, self.cancel_scope = parent_id, cancel_scope
10✔
584

585

586
# Maps each task to its TaskState; the keys are weak references, so an entry does
# not keep its task alive
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
10✔
587

588

589
#
590
# Task groups
591
#
592

593

594
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object handed to a child task as its ``task_status`` argument.

    Calling :meth:`started` resolves the future the starter is waiting on and
    records ``parent_id`` as the parent of the calling task.
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        """
        Signal that the task has started, optionally passing a value to the starter.

        :raises RuntimeError: if the future could not be resolved because it was
            already done

        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None
        else:
            # Re-parent the calling task
            this_task = cast(asyncio.Task, current_task())
            _task_states[this_task].parent_id = self._parent_id
10✔
609

610

611
def iterate_exceptions(
10✔
612
    exception: BaseException,
613
) -> Generator[BaseException, None, None]:
614
    if isinstance(exception, BaseExceptionGroup):
10✔
615
        for exc in exception.exceptions:
10✔
616
            yield from iterate_exceptions(exc)
10✔
617
    else:
618
        yield exception
10✔
619

620

621
class TaskGroup(abc.TaskGroup):
    """
    asyncio implementation of a task group.

    Child tasks run inside the group's own cancel scope. Leaving the ``async with``
    block waits for all child tasks to finish, and errors raised by the block or by
    the children are raised together as a ``BaseExceptionGroup``.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # True between __aenter__() and the end of __aexit__()
        self._active = False
        # Errors (other than CancelledError) collected from the host and child tasks
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        """Enter the group's cancel scope and allow tasks to be spawned."""
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Exit the cancel scope and wait for all child tasks to finish.

        :return: the return value of the cancel scope's ``__exit__()``
        :raises BaseExceptionGroup: if the block or any child task raised an
            exception other than ``CancelledError``
        :raises CancelledError: if the host task was cancelled while waiting for the
            child tasks (see the comment near the end of the method)

        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The block failed: cancel the children and record the error
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        # Finished child tasks remove themselves from the set in their done callback
        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the  child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` inside the group's cancel scope.

        :param func: callable that must return a coroutine object
        :param args: positional arguments for ``func``
        :param name: task name; derived from ``func`` if ``None``
        :param task_status_future: future to resolve through ``task_status``; given
            only by ``start()``
        :return: the newly created task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` did not return a coroutine object

        """

        def task_done(_task: asyncio.Task) -> None:
            # Done callback: unregister the task and route its outcome
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Unwind to the innermost chained CancelledError
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    # Nobody is waiting in start() (any more): the failure belongs
                    # to the group
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    self.cancel_scope.cancel()
                else:
                    # The task failed before calling started(): pass the error to
                    # the caller waiting in start()
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until started() is called, the calling task is recorded as the parent;
            # started() then switches it to the group's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Spawn ``func(*args)`` as a child task without waiting for it to start."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """
        Spawn ``func(*args, task_status=...)`` and wait until the child calls
        ``task_status.started()``.

        NOTE(review): the return annotation is ``None``, but the value passed to
        ``started()`` is returned to the caller.

        :raises CancelledError: if the wait was cancelled; the child task is then
            cancelled and awaited before the error is re-raised

        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
10✔
761

762

763
#
764
# Threads
765
#
766

767
# A pair of (return value, exception).
# NOTE(review): not referenced in the visible part of this module; presumably at
# most one of the two items is set. Confirm against its users.
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
10✔
768

769

770
class WorkerThread(Thread):
    """
    Thread that runs callables taken from a small bounded queue.

    Each queue item is either ``None`` (shutdown command) or a
    ``(context, func, args, future, cancel_scope)`` tuple. The outcome of the call is
    delivered to ``future`` from the event loop thread of ``root_task``.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        """
        :param root_task: task whose event loop receives the results
        :param workers: shared set this worker removes itself from in :meth:`stop`
        :param idle_workers: shared deque this worker appends itself to after each
            reported result (unless it is stopping)

        """
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # At most two items can be pending at any time
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Mark this worker idle and settle ``future`` unless it was cancelled.

        Scheduled by :meth:`run` through ``loop.call_soon_threadsafe()``.
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # Deliver a RuntimeError chained to the StopIteration instead
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        """Process queue items until the ``None`` shutdown command arrives."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                job = self.queue.get()
                if job is None:
                    # Shutdown command received
                    return

                context, func, args, future, cancel_scope = job
                if not future.cancelled():
                    outcome = None
                    failure: BaseException | None = None
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        outcome = context.run(func, *args)
                    except BaseException as exc:
                        failure = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    # Nothing can be reported once the loop has been closed
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, outcome, failure
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Queue the shutdown command and unregister this worker.

        :param f: not used by this method
            (NOTE(review): presumably lets this be used as a task done callback)

        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        # The worker may not be in the idle deque at all
        with suppress(ValueError):
            self.idle_workers.remove(self)
10✔
843

844

845
# Deque of WorkerThread instances, kept in a RunVar
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# Set of WorkerThread instances, kept in a RunVar
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
10✔
849

850

851
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop that was running at construction."""

    def __new__(cls) -> BlockingPortal:
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        # Remember the loop so other threads can schedule work onto it
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """
        Ask the portal's loop to start ``self._call_func`` in the task group.

        :param func: callable forwarded to ``_call_func``
        :param args: positional arguments forwarded to ``_call_func``
        :param kwargs: keyword arguments forwarded to ``_call_func``
        :param name: name given to ``start_soon()``
        :param future: future forwarded to ``_call_func``

        """
        starter = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(starter, call_args, self._loop)
872

873

874
#
875
# Subprocesses
876
#
877

878

879
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream backed by an :class:`asyncio.StreamReader`."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the wrapped reader.

        :raises EndOfStream: if the reader returns no data

        """
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        """Feed EOF to the wrapped reader."""
        self._stream.feed_eof()
10✔
892

893

894
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream backed by an :class:`asyncio.StreamWriter`."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write ``item`` to the wrapped writer and wait for it to drain."""
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        """Close the wrapped writer."""
        self._stream.close()
10✔
904

905

906
@dataclass(eq=False)
class Process(abc.Process):
    """An asyncio subprocess together with its optional wrapped standard streams."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close every stream that is present, then wait for the process to exit."""
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the wrapped process and return what its ``wait()`` returns."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Delegate to the wrapped process's ``terminate()``."""
        self._process.terminate()

    def kill(self) -> None:
        """Delegate to the wrapped process's ``kill()``."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """Delegate to the wrapped process's ``send_signal()``."""
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The ``pid`` of the wrapped process."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The ``returncode`` of the wrapped process."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The wrapped standard input stream, if any."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The wrapped standard output stream, if any."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The wrapped standard error stream, if any."""
        return self._stderr
10✔
954

955

956
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Synchronously close the pipe transports of worker processes and kill them.

    Only entries whose ``returncode`` is not ``None`` are touched; entries still
    reporting ``None`` are skipped.

    :param workers: the worker processes to go through
    :param _task: ignored

    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        # Not every event loop policy implements child watchers
        with suppress(NotImplementedError):
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): this acts only on processes that already have a return
        # code and leaves running ones alone, which looks inverted for a
        # "forcible shutdown" - confirm the intended condition.
        if process.returncode is not None:
            process._stdin._stream._transport.close()  # type: ignore[union-attr]
            process._stdout._stream._transport.close()  # type: ignore[union-attr]
            process._stderr._stream._transport.close()  # type: ignore[union-attr]
            process.kill()
            if child_watcher:
                child_watcher.remove_child_handler(process.pid)
×
979

980

981
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps indefinitely; when cancelled, kills every worker that has no return code
    yet and then closes all workers one by one. The cancellation itself is not
    re-raised.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # First pass: kill whatever is still running
        for worker in workers:
            if worker.returncode is None:
                worker.kill()

        # Second pass: close each worker
        for worker in workers:
            await worker.aclose()
10✔
999

1000

1001
#
1002
# Sockets and networking
1003
#
1004

1005

1006
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that buffers received chunks and exposes read/write readiness events.

    ``read_event`` is set whenever data, EOF or a lost connection is signalled;
    ``write_event`` is set while writing is not paused.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the queue and events and set the write buffer limit to zero."""
        writable = asyncio.Event()
        writable.set()
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = writable
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        """Record a ``BrokenResourceError`` caused by ``exc`` (if any), wake waiters."""
        if exc:
            error = BrokenResourceError()
            error.__cause__ = exc
            self.exception = error

        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        """Queue the received chunk and flag that data is available."""
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        """Wake up any reader and return ``True``."""
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        """Swap in a fresh, unset write event."""
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        """Set the current write event."""
        self.write_event.set()
×
1040

1041

1042
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Datagram protocol that buffers received packets in a bounded deque.

    ``read_event`` is set when a packet arrives or the connection is lost;
    ``write_event`` is set while writing is not paused.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the packet queue and events; writing starts out enabled."""
        writable = asyncio.Event()
        writable.set()
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        """Wake up readers and writers; ``exc`` is not recorded."""
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        """Queue the packet with its converted address, flag data as available."""
        converted = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, converted))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        """Remember the reported error."""
        self.exception = exc

    def pause_writing(self) -> None:
        """Clear the write event."""
        self.write_event.clear()

    def resume_writing(self) -> None:
        """Set the write event."""
        self.write_event.set()
×
1071

1072

1073
class SocketStream(abc.SocketStream):
    """Socket stream built on an asyncio transport and a :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` extra info of the transport."""
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next buffered chunk, at most ``max_bytes`` long.

        :raises ClosedResourceError: if nothing is buffered and this stream was
            closed through :meth:`aclose`
        :raises Exception: the exception stored on the protocol, if nothing is
            buffered and one was recorded
        :raises EndOfStream: if nothing is buffered otherwise

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
            ):
                # Reading is only enabled while somebody is waiting for data
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    # "raise ... from None" would reset the __cause__ that
                    # StreamProtocol.connection_lost() attached to this exception.
                    # Re-raising it from its own cause keeps the underlying error
                    # while still hiding the IndexError context.
                    error = self._protocol.exception
                    raise error from error.__cause__
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing is not paused.

        :raises ClosedResourceError: if this stream was closed through
            :meth:`aclose`
        :raises BrokenResourceError: if the transport raises ``RuntimeError`` while
            it is closing
        :raises Exception: the exception stored on the protocol, if any

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Call ``write_eof()`` on the transport, ignoring ``OSError``."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """
        Close the transport unless it is already closing.

        EOF is written first (best effort), then the transport is closed, one event
        loop iteration is yielded and the transport is aborted.
        """
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            self._transport.close()
            await sleep(0)
            self._transport.abort()
10✔
1155

1156

1157
class _RawSocketMixin:
    """
    Mixin holding a raw socket plus readiness futures for readers and writers.

    The class-level ``None`` defaults of ``_receive_future`` and ``_send_future``
    become visible again whenever the instance attribute is deleted by a completed
    future's callback.
    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket passed to the constructor."""
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that the loop completes from a reader callback."""

        def on_done(_future: object) -> None:
            # Drop the instance attribute and stop watching the socket
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        future = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, future.set_result, None)
        future.add_done_callback(on_done)
        return future

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that the loop completes from a writer callback."""

        def on_done(_future: object) -> None:
            # Drop the instance attribute and stop watching the socket
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        future = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, future.set_result, None)
        future.add_done_callback(on_done)
        return future

    async def aclose(self) -> None:
        """Close the socket once and complete any pending readiness futures."""
        if self._closing:
            return

        self._closing = True
        if self.__raw_socket.fileno() != -1:
            self.__raw_socket.close()

        if self._receive_future:
            self._receive_future.set_result(None)
        if self._send_future:
            self._send_future.set_result(None)
×
1201

1202

1203
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """Stream over a raw socket, with support for passing file descriptors."""

    async def send_eof(self) -> None:
        """Shut down the writing side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes, waiting for readability as needed.

        :raises EndOfStream: if ``recv()`` returns no data
        :raises ClosedResourceError: if ``recv()`` fails while this stream is closing
        :raises BrokenResourceError: if ``recv()`` fails otherwise

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                if not data:
                    raise EndOfStream

                return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, waiting for writability as needed.

        :raises ClosedResourceError: if ``send()`` fails while this stream is closing
        :raises BrokenResourceError: if ``send()`` fails otherwise

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                # Drop whatever part the socket accepted
                remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message along with file descriptors sent as ancillary data.

        :param msglen: maximum message length (non-negative integer)
        :param maxfds: maximum number of file descriptors (positive integer)
        :return: a tuple of (message, list of file descriptors)
        :raises ValueError: if ``msglen`` or ``maxfds`` is invalid
        :raises EndOfStream: if neither a message nor ancillary data was received
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` arrives
        :raises ClosedResourceError: if ``recvmsg()`` fails while closing
        :raises BrokenResourceError: if ``recvmsg()`` fails otherwise

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream

                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore a trailing partial item, if any
            usable = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
            fds.frombytes(cmsg_data[:usable])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send ``message`` together with file descriptors as ``SCM_RIGHTS`` data.

        Entries of ``fds`` that are neither ``int`` nor ``IOBase`` instances are
        skipped.

        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if ``sendmsg()`` fails while closing
        :raises BrokenResourceError: if ``sendmsg()`` fails otherwise

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    break
×
1317

1318

1319
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts connections on a raw socket via the event loop."""

    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket passed to the constructor."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`SocketStream`.

        ``TCP_NODELAY`` is enabled on the accepted socket.

        :raises ClosedResourceError: if the listener is closed before or while
            accepting

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            # The scope is stored so that aclose() can cancel a pending accept
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    with suppress(ValueError, NotImplementedError):
                        self._loop.remove_reader(self._raw_socket)

                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listening socket, cancelling a pending accept first."""
        if self._closed:
            return

        self._closed = True
        pending_scope = self._accept_scope
        if pending_scope:
            # Workaround for https://bugs.python.org/issue41317
            with suppress(ValueError, NotImplementedError):
                self._loop.remove_reader(self._raw_socket)

            pending_scope.cancel()
            await sleep(0)

        self._raw_socket.close()
10✔
1377

1378

1379
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections by polling a raw socket for readability."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket passed to the constructor."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`UNIXSocketStream`.

        The accepted socket is switched to non-blocking mode.

        :raises ClosedResourceError: if accepting fails after :meth:`aclose`
        :raises BrokenResourceError: if accepting fails otherwise

        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # Nothing to accept yet: wait until the socket becomes readable

                    def stop_watching(_future: object) -> None:
                        self._loop.remove_reader(self.__raw_socket)

                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(stop_watching)
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener closed and close the socket."""
        self._closed = True
        self.__raw_socket.close()
8✔
1414

1415

1416
class UDPSocket(abc.UDPSocket):
    """UDP socket built on a datagram transport and a :class:`DatagramProtocol`."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` extra info of the transport."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next buffered ``(data, address)`` packet.

        :raises ClosedResourceError: if nothing is buffered and this socket was
            closed through :meth:`aclose`
        :raises BrokenResourceError: if nothing is buffered otherwise

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not (self._protocol.read_queue or self._transport.is_closing()):
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet once writing is not paused.

        ``item`` is unpacked into the arguments of the transport's ``sendto()``.

        :raises ClosedResourceError: if this socket was closed through
            :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing otherwise

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
10✔
1462

1463

1464
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket built on a datagram transport and protocol."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` extra info of the transport."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next buffered packet.

        :raises ClosedResourceError: if nothing is buffered and this socket was
            closed through :meth:`aclose`
        :raises BrokenResourceError: if nothing is buffered otherwise

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not (self._protocol.read_queue or self._transport.is_closing()):
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                packet = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

            # Queue entries are (data, address) pairs; only the data is returned
            return packet[0]

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` once writing is not paused.

        :raises ClosedResourceError: if this socket was closed through
            :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing otherwise

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
10✔
1512

1513

1514
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Datagram socket over a raw socket, using ``recvfrom()`` and ``sendto()``."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Return the result of ``recvfrom(65536)``, waiting for readability as needed.

        :raises ClosedResourceError: if receiving fails while this socket is closing
        :raises BrokenResourceError: if receiving fails otherwise

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    packet = self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                return packet

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send a packet, waiting for writability as needed.

        ``item`` is unpacked into the arguments of the socket's ``sendto()``.

        :raises ClosedResourceError: if sending fails while this socket is closing
        :raises BrokenResourceError: if sending fails otherwise

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                return
8✔
1548

1549

1550
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket operating directly on the raw socket."""

    async def receive(self) -> bytes:
        """
        Return the bytes from ``recv()`` once a datagram can be read.

        :raises ClosedResourceError: if the call fails with :exc:`OSError` while
            ``self._closing`` is set
        :raises BrokenResourceError: if the call fails with :exc:`OSError` otherwise

        """
        running_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # The non-blocking call had nothing to return; wait and retry
                    await self._wait_until_readable(running_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` with ``send()`` on the raw socket.

        :raises ClosedResourceError: if the call fails with :exc:`OSError` while
            ``self._closing`` is set
        :raises BrokenResourceError: if the call fails with :exc:`OSError` otherwise

        """
        running_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # The non-blocking call could not proceed; wait and retry
                    await self._wait_until_writable(running_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
8✔
1584

1585

1586
# Events of the tasks currently waiting in wait_socket_readable(), keyed by socket
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
10✔
1587
# Events of the tasks currently waiting in wait_socket_writable(), keyed by socket
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
10✔
1588

1589

1590
#
1591
# Synchronization
1592
#
1593

1594

1595
class Event(BaseEvent):
    """Event implementation that delegates to a wrapped :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Allocate this class directly instead of going through BaseEvent.__new__()
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the wrapped event."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the wrapped event has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Wait until the event is set, or just checkpoint if it already is."""
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics holding the number of tasks waiting on the event."""
        waiters = self._event._waiters  # type: ignore[attr-defined]
        return EventStatistics(len(waiters))
10✔
1616

1617

1618
class CapacityLimiter(BaseCapacityLimiter):
    """
    Capacity limiter that tracks token holders in a set and waiters in a FIFO queue.

    A waiting borrower is only added to the set of borrowers once its task resumes
    after being notified through its :class:`asyncio.Event`.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Allocate this class directly instead of going through the base __new__()
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the total number of tokens, waking waiters if the total grew.

        :raises TypeError: if ``value`` is neither an ``int`` nor infinite
        :raises ValueError: if ``value`` is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently held by borrowers."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens not currently held by any borrower."""
        return self._total_tokens - len(self._borrowers)

    def _notify_next_waiter(self) -> None:
        """Wake the first queued waiter if this limiter has free capacity now."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower`` without waiting.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if other tasks are already queued or no tokens are free

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower``, waiting in FIFO order if necessary.

        If the wait is interrupted after this waiter was already notified, the
        notification is passed on to the next waiter so the freed capacity is not
        left unused.
        """
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                # We were already dequeued and notified but will never take the
                # token, so hand the wake-up over to the next waiter in line
                if event.is_set():
                    self._notify_next_waiter()

                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Give back the token taken for *borrower*, which is not necessarily
                # the current task
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower`` and wake the next waiter, if any.

        :raises RuntimeError: if ``borrower`` does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return a snapshot of borrowed/total tokens, borrowers and waiter count."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1730

1731

1732
# Limiter created on demand and returned by current_default_thread_limiter()
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
10✔
1733

1734

1735
#
1736
# Operating system signals
1737
#
1738

1739

1740
class _SignalReceiver:
    """
    Context manager and async iterator that yields signals delivered to the loop.

    Signal handlers are installed on the event loop on entry and removed on exit.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        """Queue ``signum`` and complete the pending future if still unresolved."""
        self._signal_queue.append(signum)
        pending = self._future
        if pending.done():
            return

        pending.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Deduplicate so that each signal gets exactly one handler
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        if len(self._signal_queue) == 0:
            # Nothing queued: wait on a fresh future that _deliver() will complete
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
8✔
1780

1781

1782
#
1783
# Testing and debugging
1784
#
1785

1786

1787
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` for ``task``.

    The parent ID is taken from the task's entry in ``_task_states`` and is ``None``
    when the task has no entry there.
    """
    state = _task_states.get(task)
    parent_id = None if state is None else state.parent_id
    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
10✔
1795

1796

1797
class TestRunner(abc.TestRunner):
    """
    Test runner that executes async tests and fixtures in one long-lived task.

    Coroutines are handed to the runner task through a memory object stream and
    their outcomes are delivered back through futures.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop factory takes precedence over the use_uvloop flag
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        loop = self.get_loop()
        loop.set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop of the underlying runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances; defer anything else to the default."""
        exc = context.get("exception")
        if isinstance(exc, Exception):
            self._exceptions.append(exc)
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """Re-raise (and clear) any exceptions raised in asynchronous callbacks."""
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """Await each received coroutine and publish its outcome on its future."""
        with receive_stream:
            async for awaitable, outcome in receive_stream:
                try:
                    value = await awaitable
                except BaseException as exc:
                    if not outcome.cancelled():
                        outcome.set_exception(exc)
                else:
                    if not outcome.cancelled():
                        outcome.set_result(value)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the runner task and return its result."""
        if not self._runner_task:
            # First call: create the stream pair and start the runner task
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Run an async generator fixture: yield its value, then run its teardown.

        :raises RuntimeError: if the generator yields a second time

        """
        asyncgen = fixture_func(**kwargs)
        setup = self._call_in_runner_task
        fixturevalue: T_Retval = self.get_loop().run_until_complete(
            setup(asyncgen.asend, None)
        )
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            # The generator finished as expected
            self._raise_async_exceptions()
            return

        self.get_loop().run_until_complete(asyncgen.aclose())
        raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        loop = self.get_loop()
        value = loop.run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return value

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test in the runner task, raising whatever exceptions occurred."""
        loop = self.get_loop()
        try:
            loop.run_until_complete(self._call_in_runner_task(test_func, **kwargs))
        except Exception as exc:
            # Report the test's own failure together with any callback exceptions
            self._exceptions.append(exc)

        self._raise_async_exceptions()
10✔
1929

1930

1931
class AsyncIOBackend(AsyncBackend):
10✔
1932
    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func`` in a new event loop and return its result.

        :param func: the coroutine function to call
        :param args: positional arguments to call ``func`` with
        :param kwargs: keyword arguments to call ``func`` with
        :param options: backend options; ``debug``, ``loop_factory`` and
            ``use_uvloop`` are read here (``loop_factory`` takes precedence over
            ``use_uvloop``)
        :return: the return value of ``func``

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                # Forward the keyword arguments too; they used to be dropped
                return await func(*args, **kwargs)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())
10✔
1960

1961
    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop as this backend's token."""
        loop = get_running_loop()
        return loop
10✔
1964

1965
    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's current clock value."""
        loop = get_running_loop()
        return loop.time()
10✔
1968

1969
    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class this backend uses to signal cancellation."""
        exc_class: type[BaseException] = CancelledError
        return exc_class
10✔
1972

1973
    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once via a zero-length sleep."""
        await asyncio.sleep(0)
10✔
1976

1977
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if an enclosing cancel scope was cancelled.

        Cancel scopes are walked outwards from the current task's scope; the walk
        stops at the first shielded scope. Returns immediately when there is no
        current task or the task has no recorded state.
        """
        task = current_task()
        if task is None:
            return

        try:
            scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while scope:
            if scope.cancel_called:
                # Give the event loop a chance to deliver the cancellation
                await sleep(0)
                continue

            if scope.shield:
                break

            scope = scope._parent_scope
10✔
1995

1996
    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once from inside a shielded scope."""
        shielded_scope = CancelScope(shield=True)
        with shielded_scope:
            await asyncio.sleep(0)
10✔
2000

2001
    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Sleep for ``delay`` seconds using :func:`asyncio.sleep`."""
        await asyncio.sleep(delay)
10✔
2004

2005
    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a cancel scope with the given deadline and shield setting."""
        scope = CancelScope(deadline=deadline, shield=shield)
        return scope
10✔
2010

2011
    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline among the current task's enclosing cancel scopes.

        The walk stops at the first shielded scope. If a visited scope has already
        been cancelled, ``-math.inf`` is returned; if the task has no recorded
        state, ``math.inf`` is returned.
        """
        try:
            scope = _task_states[current_task()].cancel_scope  # type: ignore[index]
        except KeyError:
            return math.inf

        effective = math.inf
        while scope:
            effective = min(effective, scope.deadline)
            if scope._cancel_called:
                return -math.inf

            if scope.shield:
                break

            scope = scope._parent_scope

        return effective
10✔
2032

2033
    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group for this backend."""
        group = TaskGroup()
        return group
10✔
2036

2037
    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event for this backend."""
        event = Event()
        return event
10✔
2040

2041
    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a capacity limiter with ``total_tokens`` tokens."""
        limiter = CapacityLimiter(total_tokens)
        return limiter
10✔
2044

2045
    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in a worker thread and return its result.

        :param func: the callable to run in the worker thread
        :param args: positional arguments for ``func``
        :param abandon_on_cancel: if ``False``, the wait for the result is shielded
            from cancellation
        :param limiter: capacity limiter to use instead of the default thread
            limiter

        """
        await cls.checkpoint()

        # Create the worker bookkeeping the first time this runs in an event loop
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if idle_workers:
                    # Reuse the most recently idled worker
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers and not (
                        now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME
                    ):
                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()
                else:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)

                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                # Hand the worker the enclosing scope unless the shield scope created
                # here is the one that matters (or there is no enclosing scope)
                worker_scope = (
                    scope
                    if abandon_on_cancel or scope._parent_scope is None
                    else scope._parent_scope
                )

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future
10✔
2103

2104
    @classmethod
    def check_cancelled(cls) -> None:
        """
        Raise ``CancelledError`` if an applicable cancel scope has been cancelled.

        Scopes are walked outwards from the thread-local current cancel scope; the
        walk stops at the first shielded scope.
        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None:
            if scope.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

            if scope.shield:
                break

            scope = scope._parent_scope
10✔
2115

2116
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in the event loop given by ``token`` and block for it.

        The new task is registered in the calling thread's current cancel scope.
        An asyncio cancellation is re-raised as
        :exc:`concurrent.futures.CancelledError`.
        """

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            this_task = cast(asyncio.Task, current_task())
            _task_states[this_task] = TaskState(None, scope)
            scope._tasks.add(this_task)
            try:
                return await func(*args)
            except CancelledError as exc:
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(this_task)

        target_loop = cast(AbstractEventLoop, token)
        # Schedule from a copied context in which sniffio reports "asyncio"
        ctx = copy_context()
        ctx.run(sniffio.current_async_library_cvar.set, "asyncio")
        coro = task_wrapper(threadlocals.current_cancel_scope)
        result_future: concurrent.futures.Future[T_Retval] = ctx.run(
            asyncio.run_coroutine_threadsafe, coro, target_loop
        )
        return result_future.result()
10✔
2143

2144
    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the event loop given by ``token`` and block for it.

        Any exception raised by ``func`` is propagated to the caller through the
        result future; exceptions that are not ``Exception`` subclasses are also
        re-raised inside the loop callback.
        """
        result_future: concurrent.futures.Future[T_Retval] = Future()

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                result_future.set_result(func(*args))
            except BaseException as exc:
                result_future.set_exception(exc)
                if not isinstance(exc, Exception):
                    raise

        target_loop = cast(AbstractEventLoop, token)
        target_loop.call_soon_threadsafe(wrapper)
        return result_future.result()
10✔
2162

2163
    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal for this backend."""
        portal = BlockingPortal()
        return portal
10✔
2166

2167
    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """
        Start a subprocess and wrap it and its standard streams.

        :param command: a shell command line if ``shell`` is true, otherwise the
            sequence of program arguments
        :param shell: ``True`` to run ``command`` through the shell
        :return: a process object whose stream wrappers are ``None`` for any
            standard stream the subprocess does not expose

        """
        await cls.checkpoint()
        # Keyword arguments shared by both ways of spawning the subprocess
        spawn_options: dict[str, Any] = {
            "stdin": stdin,
            "stdout": stdout,
            "stderr": stderr,
            "cwd": cwd,
            "env": env,
            "start_new_session": start_new_session,
        }
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command), **spawn_options
            )
        else:
            process = await asyncio.create_subprocess_exec(*command, **spawn_options)

        stdin_stream = None
        if process.stdin:
            stdin_stream = StreamWriterWrapper(process.stdin)

        stdout_stream = None
        if process.stdout:
            stdout_stream = StreamReaderWrapper(process.stdout)

        stderr_stream = None
        if process.stderr:
            stderr_stream = StreamReaderWrapper(process.stderr)

        return Process(process, stdin_stream, stdout_stream, stderr_stream)
10✔
2206

2207
    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the given worker processes to be shut down.

        Starts a shutdown task for ``workers`` and registers a done callback on the
        root task that forcibly shuts them down.
        """
        shutdown_coro = _shutdown_process_pool_on_exit(workers)
        create_task(shutdown_coro, name="AnyIO process pool shutdown task")
        root_task = find_root_task()
        forced_shutdown = partial(_forcibly_shutdown_process_pool_on_exit, workers)
        root_task.add_done_callback(forced_shutdown)
2216

2217
    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Open a TCP connection to ``host``:``port`` and wrap it in a socket stream.

        Reading is paused on the transport before it is handed to the stream.
        """
        loop = get_running_loop()
        endpoint = await loop.create_connection(
            StreamProtocol, host, port, local_addr=local_address
        )
        transport, protocol = cast(Tuple[asyncio.Transport, StreamProtocol], endpoint)
        transport.pause_reading()
        return SocketStream(transport, protocol)
10✔
2229

2230
    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect a non-blocking UNIX stream socket to ``path`` and wrap it.

        The socket is closed if connecting fails for any reason, including when the
        task is cancelled while waiting for the connection attempt to complete.

        :param path: file system path of the socket to connect to
        :return: a stream wrapping the connected socket

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        # The outer handler also covers exceptions raised while awaiting inside the
        # BlockingIOError handler (e.g. cancellation), which a sibling ``except``
        # clause of the inner ``try`` would not catch
        try:
            raw_socket.setblocking(False)
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Wait until the socket becomes writable, then try again
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
8✔
2249

2250
    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a TCP socket listener."""
        listener = TCPSocketListener(sock)
        return listener
10✔
2253

2254
    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a UNIX socket listener."""
        listener = UNIXSocketListener(sock)
        return listener
8✔
2257

2258
    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a datagram endpoint and wrap it in a UDP socket object.

        :return: a connected UDP socket if ``remote_address`` is given, otherwise an
            unconnected one
        :raises Exception: whatever exception the protocol recorded while the
            endpoint was being set up

        """
        loop = get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if remote_address:
            return ConnectedUDPSocket(transport, protocol)

        return UDPSocket(transport, protocol)
10✔
2281

2282
    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` in a UNIX datagram socket, connecting it if requested.

        If connecting to ``remote_path`` fails for any reason, including when the
        task is cancelled while waiting for the attempt to complete, ``raw_socket``
        is closed before the exception propagates.

        :param raw_socket: the socket to wrap
        :param remote_path: path to connect the socket to, or a false value to leave
            it unconnected
        :return: a connected socket object if ``remote_path`` was given, otherwise
            an unconnected one

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if not remote_path:
            return UNIXDatagramSocket(raw_socket)

        # The outer handler also covers exceptions raised while awaiting inside the
        # BlockingIOError handler (e.g. cancellation), which a sibling ``except``
        # clause of the inner ``try`` would not catch
        try:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    # Wait until the socket becomes writable, then try again
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
8✔
2305

2306
    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host`` and ``port`` using the running loop's ``getaddrinfo()``."""
        loop = get_running_loop()
        results = await loop.getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )
        return results
2328

2329
    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Resolve ``sockaddr`` using the running loop's ``getnameinfo()``."""
        loop = get_running_loop()
        return await loop.getnameinfo(sockaddr, flags)
10✔
2334

2335
    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until the event loop reports ``sock`` as readable.

        :raises BusyResourceError: if another wait for this socket is already
            registered
        :raises ClosedResourceError: if this wait's registration was removed by
            someone else before the wait ended

        """
        await cls.checkpoint()
        try:
            pending = _read_events.get()
        except LookupError:
            pending = {}
            _read_events.set(pending)

        if pending.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        ready = asyncio.Event()
        pending[sock] = ready
        loop.add_reader(sock, ready.set)
        try:
            await ready.wait()
        finally:
            # Only unregister the reader if our entry is still in place
            still_registered = pending.pop(sock, None) is not None
            if still_registered:
                loop.remove_reader(sock)

        if not still_registered:
            raise ClosedResourceError
×
2361

2362
    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Suspend the caller until the event loop reports ``sock`` as writable.

        :param sock: the socket to watch
        :raises BusyResourceError: if ``sock`` already has an entry in the
            ``_write_events`` mapping
        :raises ClosedResourceError: if the entry for ``sock`` was removed from
            the mapping by someone else before the wait finished

        """
        await cls.checkpoint()

        # Fetch the socket -> event mapping, creating it on first use
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register the socket object itself rather than sock.fileno(), so the
        # registration and the remove_writer(sock) call below refer to the same
        # file object (and match what wait_socket_readable() does)
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            # If our entry is still in the mapping, we were woken by the writer
            # callback (or interrupted) and must unregister the writer ourselves
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError
2388

2389
    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the limiter stored in ``_default_thread_limiter``.

        If no limiter has been stored yet, a new ``CapacityLimiter(40)`` is
        created, stored and returned.

        """
        try:
            limiter = _default_thread_limiter.get()
        except LookupError:
            # First use: create the limiter and remember it for later calls
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)

        return limiter
2397

2398
    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """
        Create a ``_SignalReceiver`` for the given signals.

        :param signals: the signals to pass to the receiver (as a tuple)
        :return: the newly created ``_SignalReceiver`` instance

        """
        receiver = _SignalReceiver(signals)
        return receiver
2403

2404
    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """
        Build a ``TaskInfo`` for the currently running asyncio task.

        :return: the result of ``_create_task_info()`` for ``current_task()``

        """
        task = current_task()
        return _create_task_info(task)  # type: ignore[arg-type]
2407

2408
    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """
        Build a ``TaskInfo`` for every task from ``all_tasks()`` that is not done.

        :return: a list with one ``_create_task_info()`` result per unfinished
            task

        """
        unfinished = (task for task in all_tasks() if not task.done())
        return [_create_task_info(task) for task in unfinished]
2411

2412
    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every task other than the caller is waiting on a pending future.

        The tasks returned by ``all_tasks()`` are scanned repeatedly. Whenever a
        task is found that has no future to wait on, or whose awaited future is
        already done, the caller sleeps for 0.1 seconds and the scan starts over.

        """
        await cls.checkpoint()
        own_task = current_task()
        while True:
            found_runnable = False
            for other in all_tasks():
                if other is own_task:
                    continue

                pending = other._fut_waiter  # type: ignore[attr-defined]
                if pending is None or pending.done():
                    # This task is not blocked (yet); stop scanning and retry
                    found_runnable = True
                    break

            if not found_runnable:
                return

            await sleep(0.1)
2427

2428
    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """
        Create a ``TestRunner`` from the given options.

        :param options: mapping that is expanded into keyword arguments for
            ``TestRunner``
        :return: the newly created ``TestRunner`` instance

        """
        runner = TestRunner(**options)
        return runner
2431

2432

2433
# Module-level name bound to this module's backend implementation class.
# NOTE(review): presumably this is the attribute anyio's backend loader reads
# after importing the module -- confirm against the caller.
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc