• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 5636305520

pending completion
5636305520

Pull #591

github

web-flow
Merge 0ce892f75 into 9a464bf68
Pull Request #591: Improved detection of timeout based vs direct cancellation of cancel scopes

17 of 18 new or added lines in 3 files covered. (94.44%)

139 existing lines in 3 files now uncovered.

4249 of 4715 relevant lines covered (90.12%)

8.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.82
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
10✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
21
from collections import OrderedDict, deque
10✔
22
from collections.abc import AsyncIterator, Iterable
10✔
23
from concurrent.futures import Future
10✔
24
from contextvars import Context, copy_context
10✔
25
from dataclasses import dataclass
10✔
26
from functools import partial, wraps
10✔
27
from inspect import (
10✔
28
    CORO_RUNNING,
29
    CORO_SUSPENDED,
30
    getcoroutinestate,
31
    iscoroutine,
32
)
33
from io import IOBase
10✔
34
from os import PathLike
10✔
35
from queue import Queue
10✔
36
from signal import Signals
10✔
37
from socket import AddressFamily, SocketKind
10✔
38
from threading import Thread
10✔
39
from types import TracebackType
10✔
40
from typing import (
10✔
41
    IO,
42
    Any,
43
    AsyncGenerator,
44
    Awaitable,
45
    Callable,
46
    Collection,
47
    ContextManager,
48
    Coroutine,
49
    Mapping,
50
    Optional,
51
    Sequence,
52
    Tuple,
53
    TypeVar,
54
    cast,
55
)
56
from weakref import WeakKeyDictionary
10✔
57

58
import sniffio
10✔
59

60
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
61
from .._core._eventloop import claim_worker_thread
10✔
62
from .._core._exceptions import (
10✔
63
    BrokenResourceError,
64
    BusyResourceError,
65
    ClosedResourceError,
66
    EndOfStream,
67
    WouldBlock,
68
)
69
from .._core._sockets import convert_ipv6_sockaddr
10✔
70
from .._core._streams import create_memory_object_stream
10✔
71
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
72
from .._core._synchronization import Event as BaseEvent
10✔
73
from .._core._synchronization import ResourceGuard
10✔
74
from .._core._tasks import CancelScope as BaseCancelScope
10✔
75
from ..abc import (
10✔
76
    AsyncBackend,
77
    IPSockAddrType,
78
    SocketListener,
79
    UDPPacketType,
80
    UNIXDatagramPacketType,
81
)
82
from ..lowlevel import RunVar
10✔
83
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
84

85
if sys.version_info >= (3, 11):
10✔
86
    from asyncio import Runner
5✔
87
else:
88
    import contextvars
6✔
89
    import enum
6✔
90
    import signal
6✔
91
    from asyncio import coroutines, events, exceptions, tasks
6✔
92

93
    from exceptiongroup import BaseExceptionGroup, ExceptionGroup
6✔
94

95
    class _State(enum.Enum):
        """Lifecycle stages of a :class:`Runner`.

        A runner starts out ``CREATED``, becomes ``INITIALIZED`` when its event loop
        is first set up and ends up ``CLOSED`` after :meth:`Runner.close`.
        """

        CREATED = "created"
        INITIALIZED = "initialized"
        CLOSED = "closed"
99

100
    class Runner:
        """Backport of :class:`asyncio.Runner` for Python versions before 3.11.

        The event loop is created lazily on first use (entering the ``with`` block,
        :meth:`get_loop` or :meth:`run`) and torn down by :meth:`close`, which is
        also called when the ``with`` block is left.
        """

        # Copied from CPython 3.11, adapted where 3.11-only APIs are unavailable
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            self._interrupt_count = 0
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException],
            exc_val: BaseException,
            exc_tb: TracebackType,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop.

            Cancels all remaining tasks, shuts down async generators and the default
            executor, then closes the loop. Does nothing unless the runner is
            currently initialized.
            """
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    # Event loops without shutdown_default_executor() (added in
                    # Python 3.9) use the local backport instead
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
            """Run a coroutine inside the embedded event loop.

            :param coro: the coroutine to run as the main task
            :param context: the :class:`contextvars.Context` to create the task in;
                defaults to the runner's own context, captured at initialization
            :return: the return value of ``coro``
            :raises ValueError: if ``coro`` is not a coroutine
            :raises RuntimeError: if called from a running event loop, or if the
                runner has been closed
            :raises KeyboardInterrupt: if the main task was cancelled by SIGINT
            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context

            # loop.create_task() only accepts a ``context`` argument on Python 3.11+,
            # while this backport is only used on older versions. Creating the task
            # from inside the target context makes the task copy that context.
            task = context.run(self._loop.create_task, coro)

            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # Task.uncancel() only exists on Python 3.11+. Without it the
                    # SIGINT-triggered cancellation can't be told apart from others,
                    # so treat it as final and convert it like 3.11 does.
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is None or uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """Create the event loop and capture the context on first use.

            :raises RuntimeError: if the runner has already been closed
            """
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            """SIGINT handler installed by :meth:`run`.

            The first interrupt cancels the main task; any further interrupt (or one
            arriving after the main task is done) raises KeyboardInterrupt directly.
            """
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
227

228
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """Cancel every unfinished task of ``loop`` and wait for all of them to end.

        A task that finishes with an exception other than cancellation is reported
        through the loop's exception handler instead of being raised.
        """
        pending = tasks.all_tasks(loop)
        if pending:
            for victim in pending:
                victim.cancel()

            loop.run_until_complete(tasks.gather(*pending, return_exceptions=True))

            for victim in pending:
                # exception() must not be called on a cancelled task (it would raise)
                if not victim.cancelled() and victim.exception() is not None:
                    loop.call_exception_handler(
                        {
                            "message": "unhandled exception during asyncio.run() shutdown",
                            "exception": victim.exception(),
                            "task": victim,
                        }
                    )
249

250
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """Schedule the shutdown of the default executor.

        Stand-in for ``loop.shutdown_default_executor()`` on loops lacking it. The
        blocking ``shutdown(wait=True)`` call runs in a helper thread so that the
        event loop stays responsive while the executor's workers are joined.
        """

        def _do_shutdown(done: asyncio.futures.Future) -> None:
            # Runs in the helper thread; results are handed back thread-safely
            try:
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
                loop.call_soon_threadsafe(done.set_result, None)
            except Exception as error:
                loop.call_soon_threadsafe(done.set_exception, error)

        loop._executor_shutdown_called = True
        if loop._default_executor is None:
            return

        done = loop.create_future()
        helper = threading.Thread(target=_do_shutdown, args=(done,))
        helper.start()
        try:
            await done
        finally:
            helper.join()
270

271

272
# Return type of the callables and coroutines handled here (e.g. Runner.run())
T_Retval = TypeVar("T_Retval")
# Type of the value reported through TaskStatus.started()
T_contra = TypeVar("T_contra", contravariant=True)

# Root task of the running event loop, cached by find_root_task()
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
276

277

278
def find_root_task() -> asyncio.Task:
    """Return the task regarded as the root of the current event loop.

    Candidates are tried in this order:

    1. the previously cached root task, if it has not finished yet
    2. an unfinished task driven by ``run_until_complete()`` (or uvloop's
       equivalent), recognized by its done callbacks; it gets cached
    3. the host task of the outermost AnyIO cancel scope of the current task
    4. the current task itself
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _ctx in candidate._callbacks:
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Fall back to the topmost task in the AnyIO task tree, if possible
    this_task = cast(asyncio.Task, current_task())
    state = _task_states.get(this_task)
    if state:
        scope = state.cancel_scope
        while scope and scope._parent_scope is not None:
            scope = scope._parent_scope

        if scope is not None:
            return cast(asyncio.Task, scope._host_task)

    return this_task
307

308

309
def get_callable_name(func: Callable) -> str:
    """Return ``"<module>.<qualname>"`` for ``func``.

    Either part is left out (along with its dot) when it is missing or empty.
    """
    parts = (getattr(func, "__module__", None), getattr(func, "__qualname__", None))
    return ".".join(part for part in parts if part)
313

314

315
#
316
# Event loop
317
#
318

319
# Values keyed by event loop; weak keys let an entry vanish once its loop is
# garbage collected
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
322

323

324
def _task_started(task: asyncio.Task) -> bool:
    """Return ``True`` if the task has been started and has not finished.

    :raises Exception: if the state of the task's coroutine cannot be inspected
    """
    try:
        state = getcoroutinestate(task.get_coro())
    except AttributeError:
        # The task's coroutine is an async_generator_asend object, which
        # getcoroutinestate() can't inspect: https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not") from None

    return state in (CORO_RUNNING, CORO_SUSPENDED)
331

332

333
#
334
# Timeouts and cancellation
335
#
336

337

338
class CancelScope(BaseCancelScope):
    """asyncio implementation of an AnyIO cancel scope.

    Cancellation is delivered by calling ``Task.cancel()`` on the eligible tasks in
    the scope, re-scheduled on every event loop iteration while such tasks remain.
    The calls are counted so they can be undone with ``Task.uncancel()`` (Python
    3.11+) when the scope absorbs its own cancellation on exit.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._cancel_called = False
        self._active = False
        self._timeout_handle: asyncio.TimerHandle | None = None
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        self._timeout_expired = False
        # Number of Task.cancel() calls issued by this scope (undone in _uncancel())
        self._cancel_calls: int = 0

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        # A group of cancellations must be recognized too: ExceptionGroup can never
        # contain CancelledError (a BaseException), so check BaseExceptionGroup.
        if exc_val is not None and self._is_cancellation(exc_val):
            if self._timeout_expired:
                return self._uncancel()
            elif not self._cancel_called:
                # Task was cancelled natively
                return None
            elif not self._parent_cancelled():
                # This scope was directly cancelled
                return self._uncancel()

        self._timeout_expired = False
        return None

    @staticmethod
    def _is_cancellation(exc: BaseException) -> bool:
        """Return ``True`` if ``exc`` is a CancelledError, or an exception group
        (possibly nested) that contains nothing but CancelledErrors."""
        if isinstance(exc, BaseExceptionGroup):
            return all(CancelScope._is_cancellation(inner) for inner in exc.exceptions)

        return isinstance(exc, CancelledError)

    def _uncancel(self) -> bool:
        """Undo this scope's cancellations of the host task.

        :return: ``True`` if the cancellation exception should be swallowed; on
            Python 3.11+ only when the host task has no other pending cancellation
        """
        if sys.version_info < (3, 11) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Uncancel all AnyIO cancellations
        for _ in range(self._cancel_calls):
            self._host_task.uncancel()

        self._cancel_calls = 0
        return not self._host_task.cancelling()

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed, otherwise schedule a check
        for when the deadline arrives."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self._timeout_expired = True
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    self._cancel_calls += 1
                    task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if any enclosing scope up to the nearest shielded one has
        been cancelled."""
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def deadline_reached(self) -> bool:
        return self._timeout_expired and not self._active

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            if not value:
                self._deliver_cancellation_to_parent()
564

565

566
#
567
# Task states
568
#
569

570

571
class TaskState:
    """Side-band information AnyIO keeps about an asyncio task.

    It lives outside the task object because nothing guarantees that a given Task
    implementation accepts extra attributes.
    """

    __slots__ = ("parent_id", "cancel_scope")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        # id() of the task regarded as this task's parent, if known
        self.parent_id = parent_id
        # innermost cancel scope the task is currently running in
        self.cancel_scope = cancel_scope


# Task -> TaskState; weak keys let finished tasks be garbage collected
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
585

586

587
#
588
# Task groups
589
#
590

591

592
class _AsyncioTaskStatus(abc.TaskStatus):
    """Task status object passed to tasks launched through ``TaskGroup.start()``."""

    def __init__(self, future: asyncio.Future, parent_id: int):
        # resolved with the value given to started()
        self._future = future
        # id() of the task recorded as this task's parent once started() is called
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """Signal successful startup and hand ``value`` to the waiting ``start()``.

        :raises RuntimeError: if the future was already resolved, e.g. because
            ``started()`` was called twice
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None

        running = cast(asyncio.Task, current_task())
        _task_states[running].parent_id = self._parent_id
607

608

609
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
10✔
610
    exceptions = list(excgroup.exceptions)
×
UNCOV
611
    modified = False
×
612
    for i, exc in enumerate(exceptions):
×
UNCOV
613
        if isinstance(exc, BaseExceptionGroup):
×
UNCOV
614
            new_exc = collapse_exception_group(exc)
×
UNCOV
615
            if new_exc is not exc:
×
UNCOV
616
                modified = True
×
UNCOV
617
                exceptions[i] = new_exc
×
618

UNCOV
619
    if len(exceptions) == 1:
×
UNCOV
620
        return exceptions[0]
×
UNCOV
621
    elif modified:
×
UNCOV
622
        return excgroup.derive(exceptions)
×
623
    else:
UNCOV
624
        return excgroup
×
625

626

627
class TaskGroup(abc.TaskGroup):
    """asyncio implementation of an AnyIO task group.

    Child tasks run inside the group's cancel scope. A child that fails (or an
    exception leaving the ``async with`` body) cancels that scope, and on exit all
    collected non-cancellation exceptions are raised in a BaseExceptionGroup.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        waited_for_tasks_to_finish = bool(self.cancel_scope._tasks)
        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        # Yield control to the event loop here to ensure that there is at least one
        # yield point within __aexit__() (trio does the same)
        if not waited_for_tasks_to_finish:
            await AsyncIOBackend.checkpoint()

        return ignore_exception

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """Create a child task running ``func(*args)`` inside the group's scope.

        :param func: async callable to run in the new task
        :param args: positional arguments for ``func``
        :param name: task name; ``None`` means ``get_callable_name(func)``
        :param task_status_future: if given, ``func`` also receives a
            ``task_status`` keyword argument, and this future gets the value passed
            to ``task_status.started()``, or an exception if the task ends before
            calling it
        :return: the new task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` does not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Report the earliest CancelledError in the context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    self.cancel_scope.cancel()
                else:
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            # Callables such as functools.partial objects have no __qualname__;
            # fall back to their repr so the intended TypeError is still raised
            qualname = getattr(func, "__qualname__", None)
            if qualname is None:
                description = repr(func)
            else:
                prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
                description = f"{prefix}{qualname}"

            raise TypeError(
                f"Expected {description}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """Start a child task and wait until it calls ``task_status.started()``.

        :return: the value the child passed to ``task_status.started()``
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        with CancelScope(shield=True):
            try:
                return await future
            except CancelledError:
                task.cancel()
                raise
770

771

772
#
773
# Threads
774
#
775

776
# Alias for a (result, exception) pair in which either side may be None
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
777

778

779
class WorkerThread(Thread):
    """
    Thread that runs synchronous callables on behalf of the event loop that owns
    ``root_task`` and reports their outcome back to that loop.

    Work items are ``(context, func, args, future)`` tuples put on :attr:`queue`; a
    ``None`` item tells the thread to exit.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Hand the outcome of one work item to ``future`` and mark this worker idle.

        Runs in the event loop thread (scheduled from :meth:`run` via
        ``call_soon_threadsafe``). If ``future`` was cancelled, nothing is delivered,
        but the worker still goes back on ``idle_workers`` unless it is stopping.

        :param future: the future awaiting the work item's outcome
        :param result: the callable's return value (unused when ``exc`` is set)
        :param exc: the exception raised by the callable, if any
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                if isinstance(exc, StopIteration):
                    # Re-wrap StopIteration the way a coroutine would; presumably
                    # because asyncio futures refuse StopIteration as an exception
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """
        Process work items until the ``None`` shutdown sentinel is received.

        Every item is reported back through :meth:`_report_result`. That includes
        items whose future was already cancelled before they could start; those are
        skipped rather than executed. Reporting them is what puts this worker back
        on ``idle_workers``. Without it, the worker would drop out of the idle list
        for good after such an item.
        """
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future = item
                result = None
                exception: BaseException | None = None
                if not future.cancelled():
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc

                # Results (and the "worker is idle again" bookkeeping) can only be
                # handed back while the loop is still open
                if not self.loop.is_closed():
                    self.loop.call_soon_threadsafe(
                        self._report_result, future, result, exception
                    )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Ask the thread to exit and remove it from the pool bookkeeping.

        :param f: ignored; presumably present so this can be registered as a task
            done-callback
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # Not currently on the idle list
            pass
9✔
849

850

851
# Worker thread pool state kept in RunVars: the workers currently on the idle list,
# and the set of all workers (presumably scoped to the running event loop — confirm
# against RunVar's definition)
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
10✔
855

856

857
class BlockingPortal(abc.BlockingPortal):
    """asyncio blocking portal bound to the event loop it was created in."""

    def __new__(cls) -> BlockingPortal:
        # Build an instance of this class directly instead of going through the
        # base class's __new__
        instance = object.__new__(cls)
        return instance

    def __init__(self) -> None:
        super().__init__()
        # Remember the running loop so other threads can schedule work on it
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """
        From a foreign thread, start ``self._call_func(func, args, kwargs, future)``
        as a task in the portal's task group, running on the portal's event loop.
        """
        start_task = partial(self._task_group.start_soon, name=name)
        start_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_task, start_args, self._loop)
878

879

880
#
881
# Subprocesses
882
#
883

884

885
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Adapts an :class:`asyncio.StreamReader` to AnyIO's byte receive stream API."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            # An empty read is reported as end of stream
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        # Feed EOF into the underlying reader
        reader = self._stream
        reader.feed_eof()
9✔
898

899

900
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Adapts an :class:`asyncio.StreamWriter` to AnyIO's byte send stream API."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        writer = self._stream
        writer.write(item)
        # Wait for the writer's drain() before returning
        await writer.drain()

    async def aclose(self) -> None:
        writer = self._stream
        writer.close()
9✔
910

911

912
@dataclass(eq=False)
class Process(abc.Process):
    """
    AnyIO process handle backed by :class:`asyncio.subprocess.Process`; each
    standard stream is wrapped in an AnyIO stream adapter, or is ``None``.
    """

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close every present pipe (stdin, stdout, stderr), then wait for exit."""
        for pipe in (self._stdin, self._stdout, self._stderr):
            if pipe:
                await pipe.aclose()

        await self.wait()

    async def wait(self) -> int:
        exit_code = await self._process.wait()
        return exit_code

    def terminate(self) -> None:
        proc = self._process
        proc.terminate()

    def kill(self) -> None:
        proc = self._process
        proc.kill()

    def send_signal(self, signal: int) -> None:
        proc = self._process
        proc.send_signal(signal)

    @property
    def pid(self) -> int:
        proc = self._process
        return proc.pid

    @property
    def returncode(self) -> int | None:
        proc = self._process
        return proc.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        stream = self._stdin
        return stream

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        stream = self._stdout
        return stream

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        stream = self._stderr
        return stream
9✔
960

961

962
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down worker processes belonging to this event loop.

    This works without ``await``. For every worker that is still running, it
    closes the pipe transports, kills the process and, where a child watcher is
    available (Python < 3.12), stops watching its PID. Workers that have already
    exited are left alone.

    :param workers: the worker processes to shut down
    :param _task: ignored; presumably present so this can be registered as a task
        done-callback
    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # A process that has already exited needs no shutdown
        if process.returncode is not None:
            continue

        for pipe in (process._stdin, process._stdout, process._stderr):
            if pipe is not None:
                pipe._stream._transport.close()  # type: ignore[attr-defined]

        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
×
985

986

987
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps until cancelled. On cancellation, kills every worker that is still
    running and then closes all of them; the cancellation itself is not re-raised.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    :param workers: the live set of worker processes. Other code may modify it while
        this coroutine awaits, so the closing loop iterates over a snapshot.
    """
    process: abc.Process
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        for process in workers:
            if process.returncode is None:
                process.kill()

        # Snapshot: aclose() awaits, and iterating the live set across awaits would
        # raise RuntimeError if another task adds or discards a worker meanwhile
        for process in list(workers):
            await process.aclose()
9✔
1005

1006

1007
#
1008
# Sockets and networking
1009
#
1010

1011

1012
class StreamProtocol(asyncio.Protocol):
10✔
1013
    read_queue: deque[bytes]
10✔
1014
    read_event: asyncio.Event
10✔
1015
    write_event: asyncio.Event
10✔
1016
    exception: Exception | None = None
10✔
1017

1018
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
10✔
1019
        self.read_queue = deque()
10✔
1020
        self.read_event = asyncio.Event()
10✔
1021
        self.write_event = asyncio.Event()
10✔
1022
        self.write_event.set()
10✔
1023
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)
10✔
1024

1025
    def connection_lost(self, exc: Exception | None) -> None:
10✔
1026
        if exc:
10✔
1027
            self.exception = BrokenResourceError()
10✔
1028
            self.exception.__cause__ = exc
10✔
1029

1030
        self.read_event.set()
10✔
1031
        self.write_event.set()
10✔
1032

1033
    def data_received(self, data: bytes) -> None:
10✔
1034
        self.read_queue.append(data)
10✔
1035
        self.read_event.set()
10✔
1036

1037
    def eof_received(self) -> bool | None:
10✔
1038
        self.read_event.set()
10✔
1039
        return True
10✔
1040

1041
    def pause_writing(self) -> None:
10✔
1042
        self.write_event = asyncio.Event()
10✔
1043

1044
    def resume_writing(self) -> None:
10✔
UNCOV
1045
        self.write_event.set()
×
1046

1047

1048
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    asyncio datagram protocol that buffers received ``(data, address)`` packets.

    The buffer is a ``deque`` with ``maxlen=100``, so once it is full, appending a
    new packet drops the oldest one.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up blocked readers and writers so they can observe the closure
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
×
1077

1078

1079
class SocketStream(abc.SocketStream):
    """
    Byte stream over an asyncio transport, with a :class:`StreamProtocol` doing the
    buffering and flow-control signalling.
    """

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set by aclose(); later failures are then reported as ClosedResourceError
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes from the stream.

        :raises ClosedResourceError: if no buffered data is left and this stream was
            closed with :meth:`aclose`
        :raises Exception: the exception recorded on the protocol, if no buffered
            data is left and the connection was lost with an error
        :raises EndOfStream: if no buffered data is left otherwise
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Let the transport deliver data only while we are waiting for it, and
            # pause reading again once something has been signalled
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
            ):
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport, then wait until the protocol's write event
        is set.

        :raises ClosedResourceError: if this stream was closed with :meth:`aclose`
        :raises Exception: the exception recorded on the protocol, if any
        :raises BrokenResourceError: if the write raises RuntimeError while the
            transport is closing
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            # Flow control: the protocol replaces/clears this event while writing is
            # paused
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        # OSError from write_eof() is deliberately ignored (best effort)
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """
        Close the stream: send EOF if possible, close the transport, yield to the
        event loop once and then abort the transport.
        """
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            self._transport.close()
            await sleep(0)
            self._transport.abort()
10✔
1161

1162

1163
class _RawSocketMixin:
    """
    Shared machinery for streams that drive a non-blocking raw socket directly.

    Readiness is awaited through one-shot futures registered with the event loop's
    ``add_reader``/``add_writer``. :meth:`aclose` completes any such pending future
    so that a waiting operation wakes up and notices that the socket was closed.
    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """
        Return a future that completes when the socket becomes readable.

        The future is stored as ``_receive_future`` until its done callback runs;
        that callback forgets it again and unregisters the reader.
        """

        def callback(f: object) -> None:
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """
        Return a future that completes when the socket becomes writable.

        The future is stored as ``_send_future`` until its done callback runs; that
        callback forgets it again and unregisters the writer.
        """

        def callback(f: object) -> None:
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        """Close the socket (once) and wake up any pending readiness wait."""
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # Done callbacks run asynchronously, so a stored future may already be
            # done (e.g. cancelled together with its waiting task, or resolved by the
            # reader/writer callback) before it is forgotten. Calling set_result() on
            # it would raise InvalidStateError, so only complete pending futures.
            for pending in (self._receive_future, self._send_future):
                if pending is not None and not pending.done():
                    pending.set_result(None)
×
1207

1208

1209
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """
    UNIX domain socket stream that drives its non-blocking raw socket directly and
    can also pass file descriptors (SCM_RIGHTS).
    """

    async def send_eof(self) -> None:
        # Shut down the sending direction only; receiving keeps working
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes.

        :raises EndOfStream: if the peer closed its sending side (empty read)
        :raises ClosedResourceError: if the socket fails after :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, waiting for writability whenever the socket would
        block.

        :raises ClosedResourceError: if the socket fails after :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            # A memoryview lets partial sends advance without copying the data
            view = memoryview(item)
            while view:
                try:
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message of up to ``msglen`` bytes plus up to ``maxfds`` file
        descriptors.

        :return: the message and the list of received file descriptors
        :raises ValueError: if ``msglen`` is negative or ``maxfds`` is less than 1
        :raises EndOfStream: if neither data nor ancillary data was received
        :raises RuntimeError: if ancillary data other than SCM_RIGHTS arrives
        :raises ClosedResourceError: if the socket fails after :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    # NOTE(review): ``flags`` is not inspected, so truncated ancillary
                    # data (more fds than ``maxfds``) is not reported
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Drop any trailing bytes that don't form a whole C int
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send ``message`` together with the given file descriptors (SCM_RIGHTS).

        :param fds: integer descriptors or file objects (their ``fileno()`` is sent)
        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if the socket fails after :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # Pass the descriptors as SCM_RIGHTS ancillary data.
                    # NOTE(review): sendmsg()'s return value is ignored, so a partial
                    # send of ``message`` is not retried.
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
×
1323

1324

1325
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections through the event loop's sock_accept()."""

    # Cancel scope of the accept() call currently waiting, if any; aclose() cancels it
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        :return: a :class:`SocketStream` for the new connection, with TCP_NODELAY set
        :raises ClosedResourceError: if the listener is closed, including while
            this call is waiting
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """
        Close the listener. If an accept() is waiting, unregister its reader,
        cancel it and yield once before closing the socket.
        """
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            await sleep(0)

        self._raw_socket.close()
10✔
1383

1384

1385
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections on a non-blocking UNIX domain socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False
        # Readiness future of the accept() call currently waiting, if any; aclose()
        # completes it so that the waiter wakes up instead of blocking forever
        self._accept_future: asyncio.Future | None = None

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        :return: a :class:`UNIXSocketStream` for the new connection (its socket is
            switched to non-blocking mode)
        :raises ClosedResourceError: if the listener is closed, including while
            this call is waiting
        :raises BrokenResourceError: if accepting fails for any other reason
        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, f.set_result, None)
                    f.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    self._accept_future = f
                    try:
                        await f
                    finally:
                        self._accept_future = None
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Close the listener and wake up a pending :meth:`accept`, if any."""
        self._closed = True
        self.__raw_socket.close()
        # Closing the socket is not guaranteed to fire the reader callback, so wake
        # the waiting accept() explicitly. On its next attempt, accept() on the
        # closed socket raises OSError, which it reports as ClosedResourceError.
        accept_future = self._accept_future
        if accept_future is not None and not accept_future.done():
            accept_future.set_result(None)

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
7✔
1420

1421

1422
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket on top of an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        transport = self._transport
        return transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Receive one datagram together with its sender address.

        :raises ClosedResourceError: if nothing is buffered and the socket was closed
            with :meth:`aclose`
        :raises BrokenResourceError: if nothing is buffered otherwise
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            # Nothing buffered and the transport still open: wait for a datagram
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

    async def send(self, item: UDPPacketType) -> None:
        """Send ``item`` (a ``(data, address)`` pair) once writing is allowed."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
9✔
1468

1469

1470
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """UDP socket connected to a single peer, on top of an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        transport = self._transport
        return transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Receive the payload of one datagram (the sender address is dropped).

        :raises ClosedResourceError: if nothing is buffered and the socket was closed
            with :meth:`aclose`
        :raises BrokenResourceError: if nothing is buffered otherwise
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            # Nothing buffered and the transport still open: wait for a datagram
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None
            else:
                return packet[0]

    async def send(self, item: bytes) -> None:
        """Send ``item`` to the connected peer once writing is allowed."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
9✔
1518

1519

1520
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket driven directly through a raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Receive one datagram (at most 65536 bytes) with its sender address.

        :raises ClosedResourceError: if the socket fails after :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    # Nothing queued yet: park until the socket becomes readable
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send ``item`` (a ``(data, address)`` pair).

        :raises ClosedResourceError: if the socket fails after :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
7✔
1554

1555

1556
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """UNIX datagram socket whose raw socket is connected to one peer."""

    async def receive(self) -> bytes:
        """Receive one datagram (``recv(65536)``) from the connected peer.

        Raises:
            ClosedResourceError: an ``OSError`` occurred while ``self._closing``
                was set.
            BrokenResourceError: any other ``OSError`` (chained as the cause).
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Nothing to read yet: wait for readiness, then retry
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: bytes) -> None:
        """Send ``item`` as a single datagram to the connected peer.

        Raises:
            ClosedResourceError: an ``OSError`` occurred while ``self._closing``
                was set.
            BrokenResourceError: any other ``OSError`` (chained as the cause).
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # Send buffer full: wait for writability, then retry
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
1590

1591

1592
# RunVar-held registries used by AsyncIOBackend.wait_socket_readable() and
# wait_socket_writable(): they map each socket currently being waited on to
# the asyncio.Event that the loop's reader/writer callback sets.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
1594

1595

1596
#
1597
# Synchronization
1598
#
1599

1600

1601
class Event(BaseEvent):
    """AnyIO event for the asyncio backend, backed by :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Allocate directly, bypassing any __new__ defined on BaseEvent
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the flag, waking every task blocked in :meth:`wait`."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the flag has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Block until the flag is set, then pass through one checkpoint.

        ``asyncio.Event.wait()`` only ever returns ``True``, so the checkpoint
        runs on every call, including when the event was already set.
        """
        await self._event.wait()
        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics; the count is the number of blocked waiters."""
        waiters = self._event._waiters  # type: ignore[attr-defined]
        return EventStatistics(len(waiters))
1620

1621

1622
class CapacityLimiter(BaseCapacityLimiter):
    """Capacity limiter for the asyncio backend.

    Borrowed tokens are tracked as a set of borrower objects. A borrower that
    cannot get a token right away parks on its own :class:`asyncio.Event`,
    stored in FIFO order in ``_wait_queue``. Whoever frees capacity (a release,
    or an increase of ``total_tokens``) pops the next waiter off the queue and
    sets its event. The woken task then adds itself to ``_borrowers``.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Allocate directly, bypassing any __new__ defined on the base class
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """Maximum number of tokens that can be borrowed at once."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """Change the token count, waking waiters if capacity grew.

        Raises:
            TypeError: ``value`` is neither an ``int`` nor infinite.
            ValueError: ``value`` is below 1.
        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        old_value = self._total_tokens
        self._total_tokens = value

        # Hand each newly created token to the next waiter. Waiters are popped
        # off the queue, exactly like release_on_behalf_of() does. Leaving a
        # woken waiter in the queue would leave a stale, already-set entry
        # that swallows a later release's wakeup and blocks acquire_nowait().
        waiters_to_notify = value - old_value if value > old_value else 0
        while waiters_to_notify > 0 and self._wait_queue:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """Number of tokens currently held by borrowers."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """Number of tokens not currently borrowed."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Borrow a token for the current task without blocking."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """Borrow a token for ``borrower`` without blocking.

        Raises:
            RuntimeError: ``borrower`` already holds a token.
            WouldBlock: tasks are already queued or no token is free.
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        # Queued tasks take precedence over newcomers (FIFO fairness)
        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Borrow a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Borrow a token for ``borrower``, waiting in FIFO order if necessary."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A token was handed to us (e.g. we were cancelled right
                    # after being woken). Pass the wakeup on so that the freed
                    # capacity is not stranded.
                    self._notify_next_waiter()
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Give back the token taken for *borrower*. The caller's own
                # task may not be the borrower.
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        """Return the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """Return the token held by ``borrower`` and wake the next waiter.

        Raises:
            RuntimeError: ``borrower`` does not hold a token.
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def _notify_next_waiter(self) -> None:
        """Wake the longest-waiting borrower if the limiter has free capacity."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return a snapshot of borrowed/total tokens, borrowers and waiters."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1739

1740

1741
# Holds the limiter created on demand by
# AsyncIOBackend.current_default_thread_limiter() when no limiter is passed to
# run_sync_in_worker_thread().
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
1742

1743

1744
#
1745
# Operating system signals
1746
#
1747

1748

1749
class _SignalReceiver:
10✔
1750
    def __init__(self, signals: tuple[Signals, ...]):
10✔
1751
        self._signals = signals
8✔
1752
        self._loop = get_running_loop()
8✔
1753
        self._signal_queue: deque[Signals] = deque()
8✔
1754
        self._future: asyncio.Future = asyncio.Future()
8✔
1755
        self._handled_signals: set[Signals] = set()
8✔
1756

1757
    def _deliver(self, signum: Signals) -> None:
10✔
1758
        self._signal_queue.append(signum)
8✔
1759
        if not self._future.done():
8✔
1760
            self._future.set_result(None)
8✔
1761

1762
    def __enter__(self) -> _SignalReceiver:
10✔
1763
        for sig in set(self._signals):
8✔
1764
            self._loop.add_signal_handler(sig, self._deliver, sig)
8✔
1765
            self._handled_signals.add(sig)
8✔
1766

1767
        return self
8✔
1768

1769
    def __exit__(
10✔
1770
        self,
1771
        exc_type: type[BaseException] | None,
1772
        exc_val: BaseException | None,
1773
        exc_tb: TracebackType | None,
1774
    ) -> bool | None:
1775
        for sig in self._handled_signals:
8✔
1776
            self._loop.remove_signal_handler(sig)
8✔
1777
        return None
8✔
1778

1779
    def __aiter__(self) -> _SignalReceiver:
10✔
1780
        return self
8✔
1781

1782
    async def __anext__(self) -> Signals:
10✔
1783
        await AsyncIOBackend.checkpoint()
8✔
1784
        if not self._signal_queue:
8✔
UNCOV
1785
            self._future = asyncio.Future()
×
UNCOV
1786
            await self._future
×
1787

1788
        return self._signal_queue.popleft()
8✔
1789

1790

1791
#
1792
# Testing and debugging
1793
#
1794

1795

1796
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """Build a :class:`TaskInfo` describing *task*.

    The parent id is taken from the task's entry in ``_task_states``. Tasks
    without an entry report ``None`` as their parent id.
    """
    state = _task_states.get(task)
    parent_id = None if state is None else state.parent_id
    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
1804

1805

1806
class TestRunner(abc.TestRunner):
    """Runs async tests and fixtures on an asyncio event loop.

    Every coroutine is forwarded over a memory object stream to one long-lived
    "runner task", which awaits them one by one. Fixtures and tests therefore
    all execute inside that same task.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop_factory wins over use_uvloop
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        loop = self.get_loop()
        loop.set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop owned by the underlying runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances reported by the loop for re-raising.

        Anything else goes to the loop's default handler.
        """
        exc = context.get("exception")
        if not isinstance(exc, Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(exc)

    def _raise_async_exceptions(self) -> None:
        """Re-raise exceptions collected from asynchronous callbacks.

        A single exception is raised as-is. Several are wrapped in a
        ``BaseExceptionGroup``. The collection is emptied either way.
        """
        if not self._exceptions:
            return

        exceptions, self._exceptions = self._exceptions, []
        if len(exceptions) == 1:
            raise exceptions[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", exceptions
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """Runner-task body: await each received coroutine in turn.

        Each outcome is reported through its paired future, unless that
        future was cancelled.
        """
        with receive_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)
                    continue

                if not future.cancelled():
                    future.set_result(retval)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` inside the runner task and return its result.

        The runner task is created lazily on first use.
        """
        if self._runner_task is None:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """Drive an async-generator fixture.

        Yields its single value, then runs its teardown.

        Raises:
            RuntimeError: the generator yielded a second value.
        """
        asyncgen = fixture_func(**kwargs)
        # Each call produces a coroutine that advances the generator by one step
        step = partial(self._call_in_runner_task, asyncgen.asend, None)

        fixturevalue: T_Retval = self.get_loop().run_until_complete(step())
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            self.get_loop().run_until_complete(step())
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        loop = self.get_loop()
        retval = loop.run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test coroutine.

        An ``Exception`` raised by the test is re-raised together with any
        exceptions the loop collected from asynchronous callbacks.
        """
        try:
            loop = self.get_loop()
            loop.run_until_complete(self._call_in_runner_task(test_func, **kwargs))
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
1938

1939

1940
class AsyncIOBackend(AsyncBackend):
    """AnyIO backend implemented on top of the standard :mod:`asyncio` loop."""

    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """Run ``func(*args)`` to completion on a new event loop.

        Recognised ``options``:

        * ``debug``: enable asyncio debug mode (default ``False``).
        * ``loop_factory``: callable returning the event loop to use.
        * ``use_uvloop``: use ``uvloop.new_event_loop``. Ignored when
          ``loop_factory`` is given.

        Note that ``kwargs`` is not forwarded to *func*.
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        # Build the loop through Runner (as TestRunner does) so that the
        # loop_factory / use_uvloop options actually take effect.
        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop as the thread-crossing token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the running loop's clock value."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield only if an enclosing, non-shielded-off cancel scope was cancelled.

        The scope chain is walked outwards from the current task's scope. The
        walk stops at the first shielded scope.
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield once inside a shielded cancel scope."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """Return the nearest deadline affecting the current task.

        Returns ``-inf`` if an enclosing scope (up to the first shield) has
        already been cancelled. Returns ``inf`` if the task has no recorded
        state.
        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """Run ``func(*args)`` in a pooled worker thread and await the result.

        A token from ``limiter`` (or the default thread limiter) is held for
        the duration of the call. Unless ``cancellable`` is true, the wait is
        shielded from cancellation.
        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with (limiter or cls.current_default_thread_limiter()):
            with CancelScope(shield=not cancellable):
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                context = copy_context()
                # The worker thread must not look like it is running asyncio
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, future))
                return await future

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """From a foreign thread, run ``func(*args)`` on the loop ``token``.

        Blocks until the result is available.
        """
        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, func(*args), loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """From a foreign thread, call ``func(*args)`` on the loop ``token``.

        Blocks until the result is available.
        """

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Let non-Exception errors (e.g. KeyboardInterrupt) also
                # propagate inside the event loop
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """Start a subprocess and wrap its pipes in AnyIO stream wrappers."""
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """Arrange for pooled worker processes to be shut down with the root task."""
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        # Reading is resumed on demand by the stream
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """Connect a non-blocking UNIX stream socket to ``path``."""
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                # Wait until the socket is writable, then retry connect()
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            except BaseException:
                raw_socket.close()
                raise
            else:
                return UNIXSocketStream(raw_socket)

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """Create a UDP endpoint; it is connected if ``remote_address`` is given."""
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """Wrap ``raw_socket``; connect it to ``remote_path`` first if given."""
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    # Wait until the socket is writable, then retry connect()
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                except BaseException:
                    raw_socket.close()
                    raise
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """Wait until ``sock`` is readable.

        Raises:
            BusyResourceError: another task is already waiting on ``sock``.
            ClosedResourceError: the waiter's registry entry was removed
                elsewhere while waiting.
        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # If our entry is gone, someone else removed it: treat as closed
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """Wait until ``sock`` is writable.

        Raises:
            BusyResourceError: another task is already waiting on ``sock``.
            ClosedResourceError: the waiter's registry entry was removed
                elsewhere while waiting.
        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register with the same object used by remove_writer() below, just
        # like wait_socket_readable() does
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            # If our entry is gone, someone else removed it: treat as closed
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return the default worker-thread limiter, creating it (40 tokens) on first use."""
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """Return info for every task in the loop that has not finished."""
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Wait until every other task is blocked on an unfinished future.

        The check is repeated every 0.1 seconds.
        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        return TestRunner(**options)
2405

2406

2407
# Module-level alias under which this module exposes its backend class
backend_class = AsyncIOBackend
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc