• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 6959187772

22 Nov 2023 03:09PM UTC coverage: 90.253% (+0.04%) from 90.218%
6959187772

Pull #639

github

web-flow
Merge 8a98b1e3b into f0707cdde
Pull Request #639: Fixed asyncio `CancelScope` not recognizing its own cancellation exception

9 of 10 new or added lines in 1 file covered. (90.0%)

106 existing lines in 1 file now uncovered.

4278 of 4740 relevant lines covered (90.25%)

8.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.69
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
10✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
21
from collections import OrderedDict, deque
10✔
22
from collections.abc import AsyncIterator, Iterable
10✔
23
from concurrent.futures import Future
10✔
24
from contextlib import suppress
10✔
25
from contextvars import Context, copy_context
10✔
26
from dataclasses import dataclass
10✔
27
from functools import partial, wraps
10✔
28
from inspect import (
10✔
29
    CORO_RUNNING,
30
    CORO_SUSPENDED,
31
    getcoroutinestate,
32
    iscoroutine,
33
)
34
from io import IOBase
10✔
35
from os import PathLike
10✔
36
from queue import Queue
10✔
37
from signal import Signals
10✔
38
from socket import AddressFamily, SocketKind
10✔
39
from threading import Thread
10✔
40
from types import TracebackType
10✔
41
from typing import (
10✔
42
    IO,
43
    Any,
44
    AsyncGenerator,
45
    Awaitable,
46
    Callable,
47
    Collection,
48
    ContextManager,
49
    Coroutine,
50
    Mapping,
51
    Optional,
52
    Sequence,
53
    Tuple,
54
    TypeVar,
55
    cast,
56
)
57
from weakref import WeakKeyDictionary
10✔
58

59
import sniffio
10✔
60

61
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
62
from .._core._eventloop import claim_worker_thread
10✔
63
from .._core._exceptions import (
10✔
64
    BrokenResourceError,
65
    BusyResourceError,
66
    ClosedResourceError,
67
    EndOfStream,
68
    WouldBlock,
69
)
70
from .._core._sockets import convert_ipv6_sockaddr
10✔
71
from .._core._streams import create_memory_object_stream
10✔
72
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
73
from .._core._synchronization import Event as BaseEvent
10✔
74
from .._core._synchronization import ResourceGuard
10✔
75
from .._core._tasks import CancelScope as BaseCancelScope
10✔
76
from ..abc import (
10✔
77
    AsyncBackend,
78
    IPSockAddrType,
79
    SocketListener,
80
    UDPPacketType,
81
    UNIXDatagramPacketType,
82
)
83
from ..lowlevel import RunVar
10✔
84
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
85

86
if sys.version_info >= (3, 11):
10✔
87
    from asyncio import Runner
5✔
88
else:
89
    import contextvars
6✔
90
    import enum
6✔
91
    import signal
6✔
92
    from asyncio import coroutines, events, exceptions, tasks
6✔
93

94
    from exceptiongroup import BaseExceptionGroup
6✔
95

96
    class _State(enum.Enum):
        """Lifecycle states of the backported ``Runner`` (CREATED -> INITIALIZED -> CLOSED)."""

        # Constructed, but no event loop has been created yet
        CREATED = "created"
        # _lazy_init() has created the event loop and captured a context
        INITIALIZED = "initialized"
        # close() has run; the runner cannot be reused
        CLOSED = "closed"
100

101
    class Runner:
        """
        Backport of :class:`asyncio.Runner` for Python versions older than 3.11.

        Copied from CPython 3.11 with one adjustment: ``loop.create_task()`` only
        accepts a ``context`` argument from Python 3.11 onwards, so :meth:`run`
        creates the main task inside ``context.run()`` instead. On these Python
        versions the task then runs in a copy of that context.
        """

        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            """
            :param debug: if not ``None``, passed to ``loop.set_debug()`` on the
                created loop
            :param loop_factory: callable creating the event loop; when omitted,
                ``events.new_event_loop()`` is used and the loop is also installed
                as the current event loop
            """
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            self._interrupt_count = 0
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """
            Shutdown and close event loop.

            This is a no-op unless the runner is initialized. It cancels all
            remaining tasks, then shuts down async generators and the default
            executor. Finally it closes the loop and, if this runner installed
            it, unsets the current event loop.
            """
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    # Python 3.8 has no loop.shutdown_default_executor()
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(
            self,
            coro: Coroutine[Any, Any, T_Retval],
            *,
            context: contextvars.Context | None = None,
        ) -> T_Retval:
            """
            Run a coroutine inside the embedded event loop.

            In the main thread, while the default SIGINT handler is active, a
            temporary handler is installed. The first Ctrl+C cancels the main
            task; further ones raise ``KeyboardInterrupt``.

            :param coro: the coroutine to run
            :param context: context to run the task in (defaults to the context
                captured when the runner was initialized)
            :return: the coroutine's return value
            :raises ValueError: if ``coro`` is not a coroutine
            :raises RuntimeError: if called while an event loop is running in this
                thread, or if the runner is closed
            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context

            # loop.create_task() only accepts "context" on Python 3.11+, and this
            # class is only defined for older versions. Creating the task inside
            # context.run() makes the task copy that context instead.
            task = context.run(self._loop.create_task, coro)

            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # Task.uncancel() only exists on Python 3.11+
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """Create the event loop and capture a context, once."""
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            """SIGINT handler: cancel the main task first, then raise KeyboardInterrupt."""
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
228

229
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """
        Cancel every pending task of ``loop`` and wait for all of them to finish.

        Tasks that finish with an exception other than cancellation are reported
        to the loop's exception handler instead of being raised.
        """
        pending = tasks.all_tasks(loop)
        if not pending:
            return

        for pending_task in pending:
            pending_task.cancel()

        # return_exceptions=True: one failing task must not abort the wait for the rest
        gathered = tasks.gather(*pending, return_exceptions=True)
        loop.run_until_complete(gathered)

        for finished in pending:
            if finished.cancelled():
                continue

            error = finished.exception()
            if error is not None:
                loop.call_exception_handler(
                    {
                        "message": "unhandled exception during asyncio.run() shutdown",
                        "exception": error,
                        "task": finished,
                    }
                )
250

251
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """
        Schedule the shutdown of the default executor.

        This stands in for ``loop.shutdown_default_executor()`` on loops that lack
        it. The blocking ``shutdown(wait=True)`` call runs in a helper thread, so
        the event loop keeps running while executor jobs finish.
        """
        loop._executor_shutdown_called = True  # type: ignore[attr-defined]
        if loop._default_executor is None:  # type: ignore[attr-defined]
            return

        def _shutdown_in_thread(done: asyncio.futures.Future) -> None:
            # Runs in the helper thread; results go back through the thread-safe API
            try:
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
                loop.call_soon_threadsafe(done.set_result, None)
            except Exception as error:
                loop.call_soon_threadsafe(done.set_exception, error)

        done = loop.create_future()
        helper = threading.Thread(target=_shutdown_in_thread, args=(done,))
        helper.start()
        try:
            await done
        finally:
            helper.join()
271

272

273
# Type variable for return values passed through this backend (e.g. Runner.run())
T_Retval = TypeVar("T_Retval")
# Contravariant type variable for the value accepted by TaskStatus.started()
T_contra = TypeVar("T_contra", contravariant=True)

# Cached root task, read and written by find_root_task(). NOTE(review): RunVar
# presumably scopes the value to the current event loop run - confirm in anyio.lowlevel.
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
277

278

279
def find_root_task() -> asyncio.Task:
    """
    Return the task that should be treated as the root task of this event loop.

    Lookup order:

    #. a previously cached root task, if it has not finished yet
    #. a pending task whose done-callbacks include ``run_until_complete()``'s
       callback, or a callback from ``uvloop.loop`` (this result is cached)
    #. the host task of the outermost cancel scope of the current task
    #. the current task itself
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _ in candidate._callbacks:
            started_by_loop = (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            )
            if started_by_loop:
                _root_task.set(candidate)
                return candidate

    # Fall back to the topmost task in the AnyIO task tree, if possible
    this_task = cast(asyncio.Task, current_task())
    state = _task_states.get(this_task)
    if state:
        outermost = state.cancel_scope
        while outermost and outermost._parent_scope is not None:
            outermost = outermost._parent_scope

        if outermost is not None:
            return cast(asyncio.Task, outermost._host_task)

    return this_task
308

309

310
def get_callable_name(func: Callable) -> str:
    """
    Return ``"<module>.<qualname>"`` for ``func``.

    A missing or empty ``__module__`` or ``__qualname__`` is left out, together
    with its separating dot.
    """
    parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(part for part in parts if part)
314

315

316
#
317
# Event loop
318
#
319

320
# Per-event-loop storage, keyed weakly so it goes away with the loop. NOTE(review):
# presumably the backing store for RunVar on this backend - not referenced in the
# visible part of this file; confirm.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
321

322

323
def _task_started(task: asyncio.Task) -> bool:
    """
    Return ``True`` if the task has been started and has not finished.

    The check uses the state of the task's coroutine: it counts as started
    while running or suspended, and not when created-but-not-started or closed.

    :raises Exception: if the task's coroutine state cannot be determined
    """
    try:
        state = getcoroutinestate(task.get_coro())
    except AttributeError:
        # task coro is async_generator_asend https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not") from None

    return state == CORO_RUNNING or state == CORO_SUSPENDED
330

331

332
#
333
# Timeouts and cancellation
334
#
335

336

337
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of AnyIO's cancel scope.

    Cancellation is delivered by calling ``Task.cancel()`` on every task in
    ``_tasks`` whose innermost cancel scope is this one or a non-shielded
    descendant of it. Delivery is retried on each event loop iteration (via
    ``call_soon``) for as long as such a task remains.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        # Event loop time (loop.time()) at which the scope cancels itself
        self._deadline = deadline
        # When True, cancellation of enclosing scopes does not reach into this scope
        self._shield = shield
        # Enclosing scope of the same host task (set in __enter__)
        self._parent_scope: CancelScope | None = None
        # Set once cancel() has been called
        self._cancel_called = False
        # Set in __exit__ when the exception leaving the scope came from this scope
        self._cancelled_caught = False
        # True between __enter__ and __exit__
        self._active = False
        # Pending call_at() handle that re-checks the deadline
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Pending call_soon() handle of the next _deliver_cancellation() retry
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks this scope delivers cancellation to: the host task plus any tasks
        # added by others (e.g. TaskGroup._spawn)
        self._tasks: set[asyncio.Task] = set()
        # Task that entered the scope
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls issued by this scope and not yet undone
        self._cancel_calls: int = 0
        # Host task's cancelling() count on entry (Python 3.11+ only, else None)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        # Push this scope as the host task's current scope, remembering the old one
        try:
            task_state = _task_states[host_task]
        except KeyError:
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Leave the scope and restore the parent scope as the host task's current one.

        :return: ``None`` if the scope was never cancelled. Otherwise whether the
            exception leaving the scope was attributed to this scope (a truthy
            value makes the ``with`` statement swallow it).
        :raises RuntimeError: if the scope is not active, is exited from another
            task, or is not the host task's current scope
        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if self._cancel_called:
            if isinstance(exc_val, CancelledError):
                self._cancelled_caught = self._uncancel(exc_val)
            elif isinstance(exc_val, BaseExceptionGroup) and (
                excgrp := exc_val.split(CancelledError)[0]
            ):
                # Only top-level CancelledErrors of the matching subgroup are
                # examined; stop at the first one attributed to this scope
                for exc in excgrp.exceptions:
                    if isinstance(exc, CancelledError):
                        self._cancelled_caught = self._uncancel(exc)
                        if self._cancelled_caught:
                            break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Decide whether ``cancelled_exc`` was caused by this scope.

        Where supported, this also undoes this scope's ``Task.cancel()`` calls on
        the host task.

        * Before Python 3.9 (no cancel messages), or without a host task, it
          always returns ``True``.
        * On Python 3.11+ (``_cancelling`` recorded on entry), it calls
          ``uncancel()`` once per outstanding cancel call. It returns ``True`` as
          soon as the host task's cancellation count drops to the level recorded
          on entry.
        * Otherwise it returns whether the exception carries this scope's cancel
          message.
        """
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        self._cancel_calls = 0
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        """Cancel now if the deadline has passed, else re-check at the deadline."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # asyncio already has a cancellation pending for this task
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                # Reached this scope without meeting a shield: the task is
                # eligible, so keep retrying until it leaves _tasks
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    # Don't cancel a task whose awaited future has already
                    # completed. NOTE(review): presumably because it will resume
                    # with that result on its next step - confirm.
                    waiter = task._fut_waiter  # type: ignore[attr-defined]
                    if not isinstance(waiter, asyncio.Future) or not waiter.done():
                        self._cancel_calls += 1
                        if sys.version_info >= (3, 9):
                            # The message lets _uncancel() recognize this scope
                            task.cancel(f"Cancelled by cancel scope {id(self):x}")
                        else:
                            task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            # Cancelled, but no delivery currently scheduled
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if an ancestor below the first shielded one was cancelled."""
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """
        Cancel this scope. Only the first call has any effect.

        It drops any pending deadline timer and marks the scope as cancelled. If
        the scope has been entered (it has a host task), delivery starts now.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        # Replace any pending deadline timer; re-arm only while active and uncancelled
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Unshielding exposes this scope to pending parent cancellation again
            if not value:
                self._deliver_cancellation_to_parent()
572

573

574
#
575
# Task states
576
#
577

578

579
class TaskState:
    """
    Per-task bookkeeping kept outside the ``asyncio.Task`` object.

    asyncio gives no guarantee that arbitrary attributes can be attached to a
    Task, so this information lives in a separate, slotted object.
    """

    __slots__ = ("parent_id", "cancel_scope")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        # Innermost cancel scope the task currently runs in
        self.cancel_scope = cancel_scope
        # id() of the task regarded as this task's parent, if any
        self.parent_id = parent_id
590

591

592
# TaskState for each task this backend tracks. Keys are held weakly, so an
# entry disappears once its task object is garbage collected.
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
593

594

595
#
596
# Task groups
597
#
598

599

600
class _AsyncioTaskStatus(abc.TaskStatus):
    """Task status object handed to tasks launched via ``TaskGroup.start()``."""

    def __init__(self, future: asyncio.Future, parent_id: int):
        # Resolved by started(); awaited by the task that called start()
        self._future = future
        # Parent id recorded for the task once it reports being started
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """
        Report that the task has started, passing ``value`` to the waiter.

        Also re-parents the calling task to ``parent_id``.

        :raises RuntimeError: if called more than once
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None
        else:
            me = cast(asyncio.Task, current_task())
            _task_states[me].parent_id = self._parent_id
615

616

617
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
10✔
618
    exceptions = list(excgroup.exceptions)
×
NEW
619
    modified = False
×
UNCOV
620
    for i, exc in enumerate(exceptions):
×
UNCOV
621
        if isinstance(exc, BaseExceptionGroup):
×
UNCOV
622
            new_exc = collapse_exception_group(exc)
×
UNCOV
623
            if new_exc is not exc:
×
UNCOV
624
                modified = True
×
UNCOV
625
                exceptions[i] = new_exc
×
626

UNCOV
627
    if len(exceptions) == 1:
×
UNCOV
628
        return exceptions[0]
×
UNCOV
629
    elif modified:
×
UNCOV
630
        return excgroup.derive(exceptions)
×
631
    else:
UNCOV
632
        return excgroup
×
633

634

635
class TaskGroup(abc.TaskGroup):
    """
    asyncio implementation of AnyIO's task group.

    Child tasks are created with ``create_task()`` and added to
    ``cancel_scope._tasks``, so cancelling the group's scope cancels them.
    Exceptions from children (other than cancellation) are collected in
    ``_exceptions`` and raised from ``__aexit__`` as a ``BaseExceptionGroup``.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # True between __aenter__ and the end of __aexit__; gates _spawn()
        self._active = False
        # Non-cancellation exceptions from the host block and child tasks
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks, then report errors.

        :raises BaseExceptionGroup: if the body or any child raised an exception
            other than ``CancelledError``
        :raises CancelledError: if the host task was cancelled while waiting and
            the body exited normally or its exception was swallowed by the scope
        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        # An exception from the body cancels all children
        if exc_val is not None:
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` inside this group's cancel scope.

        :param task_status_future: when given (from :meth:`start`), a
            ``task_status`` keyword argument is passed to ``func``. The future
            receives the ``started()`` value, or the child's exception if it
            fails before calling ``started()``.
        :raises RuntimeError: if the group is not active
        :raises TypeError: if ``func(*args)`` does not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            # Done-callback: unregister the child and route its outcome
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Task.exception() raises if the task was cancelled; follow the
                # __context__ chain back through chained CancelledErrors
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    self.cancel_scope.cancel()
                else:
                    # Failed before started(): hand the error to start()'s caller
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until started() is called, the caller of start() is the parent; the
            # task status then re-parents the task to the group's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Start ``func(*args)`` as a child task without waiting for it."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start ``func(*args)`` as a child task and wait until it calls ``task_status.started()``.

        :return: the value passed to ``task_status.started()``
        :raises RuntimeError: if the child exits without calling ``started()``
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
775

776

777
#
778
# Threads
779
#
780

781
# (result, exception) pair: at most one of them is meaningful. NOTE(review): presumably
# used to pass outcomes back from worker threads - usage is outside the visible section.
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
782

783

784
class WorkerThread(Thread):
    """Thread that runs synchronous callables on behalf of an asyncio event loop.

    Work items arrive on :attr:`queue` as ``(context, func, args, future)`` tuples
    and a ``None`` item tells the thread to exit. Outcomes are handed back to the
    event loop thread with ``call_soon_threadsafe`` and applied by
    :meth:`_report_result`.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        # Results are reported to the loop that owns the root task
        self.loop = root_task._loop
        # Capacity of two; presumably one work item plus the shutdown sentinel
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Mark this worker idle again and resolve ``future`` with the outcome.

        Runs on the event loop thread (scheduled from :meth:`run`). A
        ``StopIteration`` is wrapped in a ``RuntimeError`` before being set on the
        future, and nothing is set if the future was already cancelled.
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # A StopIteration cannot usefully be raised through a future/coroutine
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                work_item = self.queue.get()
                if work_item is None:
                    # Shutdown sentinel queued by stop()
                    break

                ctx, target, call_args, waiter = work_item
                if not waiter.cancelled():
                    outcome = None
                    failure: BaseException | None = None
                    try:
                        outcome = ctx.run(target, *call_args)
                    except BaseException as caught:
                        failure = caught

                    # call_soon_threadsafe() cannot be used on a closed loop
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, waiter, outcome, failure
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Ask the thread to exit and drop it from the pool bookkeeping.

        The argument is ignored; presumably it exists so this method can be used
        as a task done-callback.
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        with suppress(ValueError):
            # It may not currently be in the idle list
            self.idle_workers.remove(self)
854

855

856
# RunVar-backed registries for the worker thread pool (presumably scoped to the
# current event loop run): every live worker, and the workers currently idle and
# available for reuse (WorkerThread appends itself to the latter when done).
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
860

861

862
class BlockingPortal(abc.BlockingPortal):
    """asyncio-backed blocking portal bound to the loop it was created on."""

    def __new__(cls) -> BlockingPortal:
        # Allocate directly, presumably to bypass abc.BlockingPortal.__new__
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        # Captured so calls coming from other threads can be routed to this loop
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """Ask the portal's loop to start ``_call_func`` in the portal task group."""
        schedule = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(schedule, call_args, self._loop)
883

884

885
#
886
# Subprocesses
887
#
888

889

890
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Adapt an :class:`asyncio.StreamReader` to a byte receive stream."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            # An empty read is treated as the end of the stream
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        self._stream.feed_eof()
903

904

905
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Adapt an :class:`asyncio.StreamWriter` to a byte send stream."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        writer = self._stream
        writer.write(item)
        # Apply flow control before returning to the caller
        await writer.drain()

    async def aclose(self) -> None:
        self._stream.close()
915

916

917
@dataclass(eq=False)
class Process(abc.Process):
    """Wrap an :class:`asyncio.subprocess.Process` and its optional stdio streams."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close each stdio stream that exists, then wait for the process to exit."""
        for pipe in (self._stdin, self._stdout, self._stderr):
            if pipe:
                await pipe.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the process to exit and return its exit code."""
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        # None while the process is still running
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
965

966

967
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down worker processes belonging to this event loop.

    This is a plain (non-async) callback, so everything is done synchronously:
    pipe transports are closed, the process is killed and, on Python < 3.12, its
    child watcher handler is removed.

    :param workers: the worker processes to handle
    :param _task: ignored (presumably the task this was registered as a
        done-callback on)

    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): this skips processes that are still running and only
        # handles ones that have already exited, which seems at odds with
        # "forcibly shut down"; kept unchanged pending confirmation of intent.
        if process.returncode is None:
            continue

        # The stdio wrappers are declared Optional, so only close the pipes that
        # actually exist instead of dereferencing None
        for pipe in (process._stdin, process._stdout, process._stderr):
            if pipe is not None:
                pipe._stream._transport.close()  # type: ignore[union-attr]

        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
990

991

992
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Sleep until cancelled, then kill running workers and close all of them.

    The cancellation is absorbed rather than re-raised, so the task finishes
    normally after the cleanup.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    process: abc.Process
    try:
        await sleep(math.inf)
    except CancelledError:
        # Kill whatever is still running first ...
        still_running = [proc for proc in workers if proc.returncode is None]
        for process in still_running:
            process.kill()

        # ... then close every worker, including ones that had already exited
        for process in workers:
            await process.aclose()
1010

1011

1012
#
1013
# Sockets and networking
1014
#
1015

1016

1017
class StreamProtocol(asyncio.Protocol):
    """Stream protocol that buffers received chunks and exposes flow-control events.

    ``read_event`` is set whenever there is something for a reader to look at
    (data, EOF or connection loss). ``write_event`` is set while writing may
    proceed.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        # Writing is allowed until the transport asks us to pause
        self.write_event.set()
        # With a high-water mark of 0 the transport pauses writing whenever any
        # data is left in its buffer
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up any waiting readers and writers
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.read_event.set()
        # A true return value leaves closing the transport up to us
        return True

    def pause_writing(self) -> None:
        # A fresh (unset) event makes subsequent writers wait
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
1051

1052

1053
class DatagramProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that buffers ``(data, address)`` pairs for UDP sockets."""

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded buffer: once full, appending discards the oldest datagram
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up any waiting readers and writers
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        normalized_addr = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, normalized_addr))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1082

1083

1084
class SocketStream(abc.SocketStream):
    """Socket stream built on an asyncio transport and :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            transport = self._transport
            if not (protocol.read_event.is_set() or transport.is_closing()):
                # Only read from the socket while someone is waiting for data
                transport.resume_reading()
                await protocol.read_event.wait()
                transport.pause_reading()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                if protocol.exception:
                    raise protocol.exception from None
                raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Return at most max_bytes and put the rest back at the front
                protocol.read_queue.appendleft(chunk[max_bytes:])
                chunk = chunk[:max_bytes]

            # An empty queue means the next call has to wait for new data
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise
                raise BrokenResourceError from exc

            # Block while the protocol has paused writing
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        with suppress(OSError):
            self._transport.write_eof()

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        with suppress(OSError):
            self._transport.write_eof()

        self._transport.close()
        await sleep(0)
        self._transport.abort()
1166

1167

1168
class _RawSocketMixin:
    """Shared plumbing for socket classes that drive a raw socket directly.

    Subclasses call ``recv``/``send``-style methods on the raw socket and, on
    ``BlockingIOError``, await :meth:`_wait_until_readable` or
    :meth:`_wait_until_writable`. :meth:`aclose` resolves any pending readiness
    future so that blocked operations wake up.
    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes when the raw socket becomes readable.

        The future is also stored in ``_receive_future`` so :meth:`aclose` can
        resolve it. Once it is done, the reader is unregistered and the instance
        attribute is deleted, which makes the class-level ``None`` visible again.
        """

        def callback(f: object) -> None:
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes when the raw socket becomes writable.

        Mirrors :meth:`_wait_until_readable`, using ``_send_future`` and the
        loop's writer registration.
        """

        def callback(f: object) -> None:
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        """Close the raw socket once and wake up any task waiting for readiness."""
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # A readiness future may already be done (resolved by the loop or
            # cancelled) while its attribute-clearing done-callback is still
            # queued on the event loop. Calling set_result() on it would then
            # raise InvalidStateError, so only resolve futures that are pending.
            if self._receive_future and not self._receive_future.done():
                self._receive_future.set_result(None)
            if self._send_future and not self._send_future.done():
                self._send_future.set_result(None)
1212

1213

1214
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX domain stream socket operated directly through its raw socket.

    Every operation retries after waiting for readiness on ``BlockingIOError``.
    Other ``OSError`` exceptions become ``ClosedResourceError`` while the stream
    is closing, and ``BrokenResourceError`` otherwise.
    """

    async def send_eof(self) -> None:
        with self._send_guard:
            # Half-close: shut down only the sending direction
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    chunk = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                # An empty read is treated as end of stream
                if not chunk:
                    raise EndOfStream

                return chunk

    async def send(self, item: bytes) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                # send() may be partial; continue with whatever is left
                remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            # Ancillary buffer large enough for maxfds C ints
            ancbufsize = socket.CMSG_LEN(maxfds * fds.itemsize)
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, ancbufsize
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream

                break

        for level, kind, payload in ancdata:
            if level != socket.SOL_SOCKET or kind != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {level}, cmsg_type = {kind}"
                )

            # Ignore any trailing bytes that do not form a whole int
            usable = len(payload) - len(payload) % fds.itemsize
            fds.frombytes(payload[:usable])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        # NOTE(review): entries that are neither int nor IOBase are silently skipped
        filenos: list[int] = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
            while True:
                try:
                    self._raw_socket.sendmsg([message], ancillary)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    break
1328

1329

1330
class TCPSocketListener(abc.SocketListener):
    """TCP listener that accepts connections via the event loop's ``sock_accept()``."""

    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _discard_reader(self) -> None:
        """Unregister any reader on the listening socket, ignoring failures."""
        # Workaround for https://bugs.python.org/issue41317
        with suppress(ValueError, NotImplementedError):
            self._loop.remove_reader(self._raw_socket)

    async def accept(self) -> abc.SocketStream:
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            # The scope is exposed so aclose() can cancel a pending accept
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    self._discard_reader()
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            self._discard_reader()
            self._accept_scope.cancel()
            # Yield once, presumably so the cancelled accept() can unwind before
            # the socket is closed
            await sleep(0)

        self._raw_socket.close()
1388

1389

1390
class UNIXSocketListener(abc.SocketListener):
    """UNIX domain listener that polls a raw socket with ``accept()``.

    NOTE(review): aclose() does not resolve the readiness future, so an accept()
    blocked in _wait_for_connection() may not be woken by it — confirm.
    """

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def _wait_for_connection(self) -> None:
        """Block until the listening socket becomes readable."""
        ready: asyncio.Future = asyncio.Future()
        self._loop.add_reader(self.__raw_socket, ready.set_result, None)
        ready.add_done_callback(
            lambda _: self._loop.remove_reader(self.__raw_socket)
        )
        await ready

    async def accept(self) -> abc.SocketStream:
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    await self._wait_for_connection()
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        self._closed = True
        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
1425

1426

1427
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Nothing buffered and the transport still open: wait for a datagram
            if not (self._protocol.read_queue or self._transport.is_closing()):
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            # Respect the protocol's flow control first
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1473

1474

1475
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket; only payloads are exchanged, without addresses."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Nothing buffered and the transport still open: wait for a datagram
            if not (self._protocol.read_queue or self._transport.is_closing()):
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                packet = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                raise BrokenResourceError from None

            # Drop the sender address; the socket is connected to a single peer
            return packet[0]

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            # Respect the protocol's flow control first
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1523

1524

1525
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket operated through its raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    # Reads at most 65536 bytes per datagram, plus the sender
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
1559

1560

1561
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket built on a raw socket object.

    Unlike the unconnected variant, packets are plain ``bytes``, exchanged
    with ``recv()``/``send()`` on the already-connected peer.
    """

    async def receive(self) -> bytes:
        """Receive one datagram of up to 65536 bytes from the connected peer.

        :raises ClosedResourceError: if an ``OSError`` occurs while
            ``self._closing`` is set
        :raises BrokenResourceError: on any other ``OSError``
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Nothing queued yet: park until the socket becomes readable
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: bytes) -> None:
        """Send ``item`` as a single datagram to the connected peer.

        :raises ClosedResourceError: if an ``OSError`` occurs while
            ``self._closing`` is set
        :raises BrokenResourceError: on any other ``OSError``
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # Send buffer full: park until the socket becomes writable
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
1595

1596

1597
# Registry used by AsyncIOBackend.wait_socket_readable(): maps each socket that
# currently has a pending read wait to the asyncio.Event that wakes the waiter.
# Stored in a RunVar (presumably scoped to the current event loop run — confirm)
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
10✔
1598
# Registry used by AsyncIOBackend.wait_socket_writable(): maps each socket that
# currently has a pending write wait to the asyncio.Event that wakes the waiter.
# Stored in a RunVar (presumably scoped to the current event loop run — confirm)
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
10✔
1599

1600

1601
#
1602
# Synchronization
1603
#
1604

1605

1606
class Event(BaseEvent):
    """asyncio-backed event: a thin wrapper around :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Allocate directly with object.__new__, bypassing any __new__ that
        # BaseEvent may define
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the internal flag, releasing the tasks blocked in :meth:`wait`."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` once :meth:`set` has been called."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Block until the event is set.

        When the event is already set, this still yields to the event loop
        once, so it always acts as a checkpoint.
        """
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics holding the number of tasks currently waiting."""
        waiters = self._event._waiters  # type: ignore[attr-defined]
        return EventStatistics(len(waiters))
1627

1628

1629
class CapacityLimiter(BaseCapacityLimiter):
    """asyncio implementation of a capacity limiter.

    Keeps the set of current borrowers (at most ``total_tokens`` of them) and a
    FIFO queue of waiting borrowers. Each waiter has an :class:`asyncio.Event`
    that is set when a token is handed to it. Invariant: a waiter is removed
    from the queue at the same moment its event is set, so every queued event
    is still unset.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        # Waiting borrowers in arrival order -> event set when they get a token
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """Change the number of tokens, waking waiters if capacity grew.

        :raises TypeError: if ``value`` is neither an ``int`` nor infinite
        :raises ValueError: if ``value`` is less than 1
        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        old_value = self._total_tokens
        self._total_tokens = value
        if value <= old_value:
            return

        # Hand each newly created token to one waiter, in FIFO order. Waiters
        # are popped as they are woken. Leaving them queued would make
        # acquire_on_behalf_of_nowait() keep raising WouldBlock, and would let
        # release() spend its wake-up on a waiter that was already woken.
        new_tokens = value - old_value
        while self._wait_queue and new_tokens > 0:
            self._wait_queue.popitem(last=False)[1].set()
            new_tokens -= 1

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """Take a token for ``borrower`` without waiting.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if other borrowers are queued or no token is free
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        # Queued waiters have priority over newcomers
        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Take a token for ``borrower``, waiting in FIFO order if none is free."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A token was already handed to us but we are giving up
                    # (e.g. cancelled); pass it on so the next waiter isn't
                    # stranded until some unrelated release().
                    self._notify_next_waiter()

                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def release(self) -> None:
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """Return ``borrower``'s token and wake the next waiter if possible.

        :raises RuntimeError: if ``borrower`` does not hold a token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def _notify_next_waiter(self) -> None:
        """Wake the longest-waiting borrower if this limiter has free capacity."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1746

1747

1748
# Lazily populated by AsyncIOBackend.current_default_thread_limiter() with a
# 40-token CapacityLimiter. run_sync_in_worker_thread() uses it when no explicit
# limiter is passed.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
10✔
1749

1750

1751
#
1752
# Operating system signals
1753
#
1754

1755

1756
class _SignalReceiver:
    """Receive OS signals through the running event loop.

    Entering the (synchronous) context manager installs event loop signal
    handlers for the requested signals; exiting removes them. While inside,
    the object is an async iterator yielding each received signal in order.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        # Signals received but not yet consumed by __anext__()
        self._signal_queue: deque[Signals] = deque()
        # Resolved by _deliver() to wake up a pending __anext__()
        self._future: asyncio.Future = asyncio.Future()
        # Signals whose handlers were actually installed (and must be removed)
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        self._signal_queue.append(signum)
        if not self._future.done():
            self._future.set_result(None)

    def _remove_handlers(self) -> None:
        """Uninstall every signal handler this receiver installed."""
        for sig in self._handled_signals:
            self._loop.remove_signal_handler(sig)

        self._handled_signals.clear()

    def __enter__(self) -> _SignalReceiver:
        try:
            for sig in set(self._signals):
                self._loop.add_signal_handler(sig, self._deliver, sig)
                self._handled_signals.add(sig)
        except BaseException:
            # __exit__() is not called when __enter__() raises, so roll back
            # the handlers installed so far instead of leaking them
            self._remove_handlers()
            raise

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        self._remove_handlers()
        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Wait for _deliver() to resolve a fresh future
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
1796

1797

1798
#
1799
# Testing and debugging
1800
#
1801

1802

1803
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """Build a :class:`TaskInfo` snapshot describing ``task``.

    The parent id comes from the task's entry in ``_task_states``; tasks
    without an entry are reported with no parent.
    """
    state = _task_states.get(task)
    parent_id = None if state is None else state.parent_id
    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
1811

1812

1813
class TestRunner(abc.TestRunner):
    """Run async fixtures and tests on a single asyncio event loop.

    Each coroutine is sent over a memory object stream to one long-lived
    runner task, which awaits it and reports the outcome through an
    :class:`asyncio.Future`. ``Exception`` instances passed to the loop's
    exception handler are collected and re-raised once the current fixture
    or test step has finished.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop_factory always wins over use_uvloop
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop owned by the underlying runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        # Collect regular exceptions for later re-raising; everything else is
        # left to the loop's default handler
        exc = context.get("exception")
        if not isinstance(exc, Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(exc)

    def _raise_async_exceptions(self) -> None:
        """Re-raise (and clear) exceptions collected from async callbacks."""
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        # Body of the long-lived runner task: await each submitted awaitable
        # and copy its outcome into the paired future, unless the caller has
        # already cancelled that future
        with receive_stream:
            async for awaitable, result_future in receive_stream:
                try:
                    value = await awaitable
                except BaseException as exc:
                    if not result_future.cancelled():
                        result_future.set_exception(exc)

                    continue

                if not result_future.cancelled():
                    result_future.set_result(value)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` inside the runner task and return its result.

        The runner task and its stream are created on first use.
        """
        loop = self.get_loop()
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = loop.create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        result_future: asyncio.Future[T_Retval] = loop.create_future()
        self._send_stream.send_nowait((coro, result_future))
        return await result_future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        asyncgen = fixture_func(**kwargs)

        def advance():
            # Resume the fixture generator inside the runner task
            return self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )

        fixturevalue: T_Retval = advance()
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            advance()
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            # The fixture yielded a second time: close it and complain
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        value = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return value

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        # A failing test is queued together with any callback exceptions so
        # that everything is reported in one go
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
1945

1946

1947
class AsyncIOBackend(AsyncBackend):
    """AnyIO backend implementation built on the standard :mod:`asyncio` loop."""

    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """Run ``func(*args)`` to completion on a fresh event loop.

        Recognized ``options``:

        * ``debug``: enable asyncio debug mode (default ``False``)
        * ``loop_factory``: callable returning the event loop to use
        * ``use_uvloop``: when true and no ``loop_factory`` is given, use
          ``uvloop.new_event_loop``

        ``kwargs`` is not forwarded to ``func``.
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        # These options used to be read and then discarded; honour them the
        # same way TestRunner does
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop (the token run_*_from_thread expects)."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the current time according to the event loop's clock."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield to the event loop only if an effective cancel scope was cancelled.

        The current task's cancel scopes are walked outwards, stopping at the
        first shielded scope.
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the event loop while shielded from cancellation."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """Return the nearest deadline affecting the current task.

        Returns ``-inf`` if an unshielded enclosing scope has already been
        cancelled, and ``inf`` if the task has no tracked state.
        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """Run ``func(*args)`` in a pooled worker thread and return its result.

        Idle workers are reused; those idle for ``WorkerThread.MAX_IDLE_TIME``
        seconds or longer are stopped. Unless ``cancellable`` is true, the
        wait is shielded from cancellation.
        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not cancellable):
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, future))
                return await future

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """Run ``func(*args)`` on the loop given as ``token`` and block for the result."""
        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, func(*args), loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """Call ``func(*args)`` on the loop given as ``token`` and block for the result."""

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception errors (e.g. KeyboardInterrupt) also propagate
                # in the event loop thread
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """Start a subprocess and wrap its pipes in AnyIO stream wrappers."""
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        # Reading is paused until the stream explicitly asks for data
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """Connect a non-blocking UNIX stream socket to ``path``.

        The raw socket is closed if connecting fails or is interrupted.
        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            except BaseException:
                raw_socket.close()
                raise
            else:
                return UNIXSocketStream(raw_socket)

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """Create a UDP socket; it is connected if ``remote_address`` is given."""
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """Wrap ``raw_socket``, connecting it to ``remote_path`` first if given."""
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                except BaseException:
                    raw_socket.close()
                    raise
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """Wait until ``sock`` is readable.

        :raises BusyResourceError: if another task is already waiting to read
            from ``sock``
        :raises ClosedResourceError: if the wait was ended by someone removing
            the registry entry (i.e. the socket was closed) rather than by
            readability
        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """Wait until ``sock`` is writable.

        :raises BusyResourceError: if another task is already waiting to write
            to ``sock``
        :raises ClosedResourceError: if the wait was ended by someone removing
            the registry entry (i.e. the socket was closed) rather than by
            writability
        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register the socket object itself, matching remove_writer() below
        # and wait_socket_readable()
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return the shared worker thread limiter, creating it with 40 tokens."""
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Poll until every other task is waiting on a pending future."""
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        return TestRunner(**options)
2412

2413

2414
# Module-level alias for this backend's implementation class (presumably what
# anyio's backend loader looks up by name — confirm)
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc