• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 5643411386

pending completion
5643411386

Pull #591

github

web-flow
Merge 8f06e6cc6 into 9a464bf68
Pull Request #591: Improved detection of timeout based vs direct cancellation of cancel scopes

29 of 30 new or added lines in 3 files covered. (96.67%)

96 existing lines in 1 file now uncovered.

4251 of 4716 relevant lines covered (90.14%)

8.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.89
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
10✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
21
from collections import OrderedDict, deque
10✔
22
from collections.abc import AsyncIterator, Iterable
10✔
23
from concurrent.futures import Future
10✔
24
from contextvars import Context, copy_context
10✔
25
from dataclasses import dataclass
10✔
26
from functools import partial, wraps
10✔
27
from inspect import (
10✔
28
    CORO_RUNNING,
29
    CORO_SUSPENDED,
30
    getcoroutinestate,
31
    iscoroutine,
32
)
33
from io import IOBase
10✔
34
from os import PathLike
10✔
35
from queue import Queue
10✔
36
from signal import Signals
10✔
37
from socket import AddressFamily, SocketKind
10✔
38
from threading import Thread
10✔
39
from types import TracebackType
10✔
40
from typing import (
10✔
41
    IO,
42
    Any,
43
    AsyncGenerator,
44
    Awaitable,
45
    Callable,
46
    Collection,
47
    ContextManager,
48
    Coroutine,
49
    Mapping,
50
    Optional,
51
    Sequence,
52
    Tuple,
53
    TypeVar,
54
    cast,
55
)
56
from weakref import WeakKeyDictionary
10✔
57

58
import sniffio
10✔
59

60
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
61
from .._core._eventloop import claim_worker_thread
10✔
62
from .._core._exceptions import (
10✔
63
    BrokenResourceError,
64
    BusyResourceError,
65
    ClosedResourceError,
66
    EndOfStream,
67
    WouldBlock,
68
)
69
from .._core._sockets import convert_ipv6_sockaddr
10✔
70
from .._core._streams import create_memory_object_stream
10✔
71
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
72
from .._core._synchronization import Event as BaseEvent
10✔
73
from .._core._synchronization import ResourceGuard
10✔
74
from .._core._tasks import CancelScope as BaseCancelScope
10✔
75
from ..abc import (
10✔
76
    AsyncBackend,
77
    IPSockAddrType,
78
    SocketListener,
79
    UDPPacketType,
80
    UNIXDatagramPacketType,
81
)
82
from ..lowlevel import RunVar
10✔
83
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
84

85
if sys.version_info >= (3, 11):
10✔
86
    from asyncio import Runner
5✔
87
else:
88
    import contextvars
6✔
89
    import enum
6✔
90
    import signal
6✔
91
    from asyncio import coroutines, events, exceptions, tasks
6✔
92

93
    from exceptiongroup import BaseExceptionGroup
6✔
94

95
    class _State(enum.Enum):
6✔
96
        CREATED = "created"
6✔
97
        INITIALIZED = "initialized"
6✔
98
        CLOSED = "closed"
6✔
99

100
    class Runner:
        """Backport of :class:`asyncio.Runner` for interpreters older than 3.11.

        Copied from CPython 3.11. The runner owns a single event loop that is created
        lazily on first use, reused by every :meth:`run` call and shut down by
        :meth:`close` (also called when leaving a ``with`` block).
        """

        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            self._interrupt_count = 0
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop.

            Cancels the remaining tasks, finalizes async generators and shuts down the
            default executor before closing the loop. Does nothing unless the runner
            is currently initialized.
            """
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    # Loops without this method (e.g. on Python 3.8) use the
                    # module-level fallback
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(
            self,
            coro: Coroutine[Any, Any, T_Retval],
            *,
            context: Context | None = None,
        ) -> T_Retval:
            """Run a coroutine inside the embedded event loop.

            :param coro: the coroutine to run as the main task
            :param context: context to create the task in; defaults to the context
                captured when the runner was initialized
            :return: the return value of ``coro``
            :raises ValueError: if ``coro`` is not a coroutine
            :raises RuntimeError: if an event loop is already running in this thread,
                or if the runner has been closed
            :raises KeyboardInterrupt: if SIGINT cancelled the main task and the task
                supports ``uncancel()``
            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context

            # loop.create_task() only accepts a ``context`` argument on Python 3.11+,
            # and this backport is only used on older versions, where passing it
            # raises TypeError. A task copies the context that is current when it is
            # created, so create it from inside the requested context instead. The
            # task therefore runs in a copy of that context.
            task = context.run(self._loop.create_task, coro)

            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default SIGINT handler only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """Create the event loop and capture a context on first use."""
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            """SIGINT handler: cancel the main task once, then raise KeyboardInterrupt."""
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
227

228
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
6✔
229
        to_cancel = tasks.all_tasks(loop)
6✔
230
        if not to_cancel:
6✔
231
            return
×
232

233
        for task in to_cancel:
6✔
234
            task.cancel()
6✔
235

236
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
6✔
237

238
        for task in to_cancel:
6✔
239
            if task.cancelled():
6✔
240
                continue
6✔
241
            if task.exception() is not None:
5✔
242
                loop.call_exception_handler(
2✔
243
                    {
244
                        "message": "unhandled exception during asyncio.run() shutdown",
245
                        "exception": task.exception(),
246
                        "task": task,
247
                    }
248
                )
249

250
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
6✔
251
        """Schedule the shutdown of the default executor."""
252

253
        def _do_shutdown(future: asyncio.futures.Future) -> None:
3✔
254
            try:
3✔
255
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
3✔
256
                loop.call_soon_threadsafe(future.set_result, None)
3✔
257
            except Exception as ex:
×
258
                loop.call_soon_threadsafe(future.set_exception, ex)
×
259

260
        loop._executor_shutdown_called = True
3✔
261
        if loop._default_executor is None:
3✔
262
            return
3✔
263
        future = loop.create_future()
3✔
264
        thread = threading.Thread(target=_do_shutdown, args=(future,))
3✔
265
        thread.start()
3✔
266
        try:
3✔
267
            await future
3✔
268
        finally:
269
            thread.join()
3✔
270

271

272
# Generic return type used in the annotations of this backend
T_Retval = TypeVar("T_Retval")
# Contravariant type of the value accepted by ``TaskStatus.started()``
T_contra = TypeVar("T_contra", contravariant=True)

# Root task cached by ``find_root_task()``, held in a RunVar
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
276

277

278
def find_root_task() -> asyncio.Task:
    """Return the task at the root of the current task tree.

    Candidates are tried in this order:

    1. the cached root task, if it has not finished;
    2. a pending task driven by ``run_until_complete()`` (or by uvloop), which is
       then cached;
    3. the host task of the outermost cancel scope of the current task;
    4. the current task itself.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # A task started through run_until_complete() carries the loop's completion
    # callback
    for candidate in all_tasks():
        if candidate.done() or not candidate._callbacks:
            continue

        for callback, _ctx in candidate._callbacks:
            is_loop_callback = callback is _run_until_complete_cb
            is_uvloop_callback = getattr(callback, "__module__", None) == "uvloop.loop"
            if is_loop_callback or is_uvloop_callback:
                _root_task.set(candidate)
                return candidate

    # Fall back to the topmost task in the AnyIO task tree, if possible
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if state:
        outermost = state.cancel_scope
        while outermost and outermost._parent_scope is not None:
            outermost = outermost._parent_scope

        if outermost is not None:
            return cast(asyncio.Task, outermost._host_task)

    return current
307

308

309
def get_callable_name(func: Callable) -> str:
    """Return ``"<module>.<qualname>"`` for *func*, leaving out missing or empty parts."""
    name_parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(part for part in name_parts if part)
313

314

315
#
316
# Event loop
317
#
318

319
_run_vars = (
10✔
320
    WeakKeyDictionary()
321
)  # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
322

323

324
def _task_started(task: asyncio.Task) -> bool:
10✔
325
    """Return ``True`` if the task has been started and has not finished."""
326
    try:
10✔
327
        return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
10✔
328
    except AttributeError:
×
329
        # task coro is async_genenerator_asend https://bugs.python.org/issue37771
330
        raise Exception(f"Cannot determine if task {task} has started or not") from None
×
331

332

333
#
334
# Timeouts and cancellation
335
#
336

337

338
class CancelScope(BaseCancelScope):
    """asyncio implementation of an AnyIO cancel scope.

    A scope governs the tasks in ``_tasks``: its host task, plus any child tasks a
    task group adds. It can have a deadline and can shield its contents from the
    cancellation of enclosing scopes. Cancelling the scope repeatedly calls
    ``Task.cancel()`` on every eligible task until none remains. ``__exit__()``
    then decides whether the resulting ``CancelledError`` belongs to this scope
    and should be swallowed.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Instantiate this class directly, skipping any __new__ logic of the base
        # class
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        # The scope that was current in the host task when this one was entered
        self._parent_scope: CancelScope | None = None
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # loop.call_at() handle that fires _timeout() at the deadline
        self._timeout_handle: asyncio.TimerHandle | None = None
        # loop.call_soon() handle of the next _deliver_cancellation() attempt
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks directly governed by this scope
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        self._timeout_expired = False
        # Number of Task.cancel() calls this scope has made on its host task that
        # have not yet been undone by _uncancel()
        self._cancel_calls: int = 0

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # No state recorded for this task yet
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest this scope inside the task's current scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        # Cancels immediately if the deadline has already passed, otherwise
        # schedules a timeout callback
        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Leave the scope and decide whether to swallow a ``CancelledError``.

        A ``CancelledError`` is swallowed only when this scope caused it, either
        because its deadline expired or because ``cancel()`` was called on it while
        no enclosing unshielded scope was also cancelled. On Python 3.11+ it still
        propagates if the host task has other pending cancellation requests (see
        :meth:`_uncancel`).

        :raises RuntimeError: if the scope is not active, is exited from a different
            task, or is not the task's current scope
        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if isinstance(exc_val, CancelledError):
            if self._timeout_expired:
                self._cancelled_caught = self._uncancel()
                return self._cancelled_caught
            elif not self._cancel_called:
                # Task was cancelled natively
                return None
            elif not self._parent_cancelled():
                # This scope was directly cancelled
                self._cancelled_caught = self._uncancel()
                return self._cancelled_caught

        return None

    def _uncancel(self) -> bool:
        """Undo the cancellation requests this scope made on its host task.

        :return: ``True`` if the ``CancelledError`` may be swallowed, meaning the
            host task has no other cancellation request left. Before Python 3.11,
            tasks do not count cancellation requests, so this is always ``True``.
        """
        if sys.version_info < (3, 11) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Uncancel all AnyIO cancellations
        for _ in range(self._cancel_calls):
            self._host_task.uncancel()

        self._cancel_calls = 0
        return not self._host_task.cancelling()

    def _timeout(self) -> None:
        """Cancel the scope if the deadline has passed, else re-check at the deadline."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self._timeout_expired = True
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            if task._must_cancel:  # type: ignore[attr-defined]
                # A cancellation is already queued for this task
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    # _uncancel() undoes these requests on the host task only, so
                    # only host task cancellations may be counted. Counting child
                    # task cancellations as well would make _uncancel() call
                    # Task.uncancel() on the host more often than this scope
                    # cancelled it. That would erase cancellation requests made by
                    # other code (e.g. asyncio.timeout()) and swallow their
                    # CancelledError.
                    if task is self._host_task:
                        self._cancel_calls += 1

                    task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            # Only restart scopes that were cancelled but have no delivery pending
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if an enclosing scope up to the nearest shield was cancelled."""
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """Cancel this scope; a no-op if it has already been cancelled."""
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            # If the scope has not been entered yet, __enter__() starts delivery
            if self._host_task is not None:
                self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timeout for the new deadline while the scope is running
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Dropping the shield exposes this scope to pending parent cancellation
            if not value:
                self._deliver_cancellation_to_parent()
562

563

564
#
565
# Task states
566
#
567

568

569
class TaskState:
    """
    Per-task bookkeeping kept outside the ``Task`` object.

    Nothing can be attached to the Task instance itself, because there are no
    guarantees about how it is implemented.
    """

    __slots__ = ("parent_id", "cancel_scope")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        # The innermost cancel scope the task currently runs in
        self.cancel_scope = cancel_scope
        # id() of the parent task, or None when unknown
        self.parent_id = parent_id
580

581

582
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
10✔
583

584

585
#
586
# Task groups
587
#
588

589

590
class _AsyncioTaskStatus(abc.TaskStatus):
    """Task status passed to a child task started with ``TaskGroup.start()``.

    :param future: resolved with the value given to :meth:`started`
    :param parent_id: recorded as the calling task's ``parent_id`` once started
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._future = future
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """Report that the task has started, handing *value* to the waiting starter.

        :raises RuntimeError: if the future has already been resolved
        """
        future = self._future
        try:
            future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None

        this_task = cast(asyncio.Task, current_task())
        _task_states[this_task].parent_id = self._parent_id
605

606

607
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
10✔
608
    exceptions = list(excgroup.exceptions)
×
609
    modified = False
×
610
    for i, exc in enumerate(exceptions):
×
UNCOV
611
        if isinstance(exc, BaseExceptionGroup):
×
612
            new_exc = collapse_exception_group(exc)
×
UNCOV
613
            if new_exc is not exc:
×
UNCOV
614
                modified = True
×
UNCOV
615
                exceptions[i] = new_exc
×
616

UNCOV
617
    if len(exceptions) == 1:
×
UNCOV
618
        return exceptions[0]
×
UNCOV
619
    elif modified:
×
UNCOV
620
        return excgroup.derive(exceptions)
×
621
    else:
UNCOV
622
        return excgroup
×
623

624

625
class TaskGroup(abc.TaskGroup):
    """asyncio implementation of an AnyIO task group.

    Child tasks run under the group's :class:`CancelScope`, whose host task is the
    task that entered the group. A child ending with an exception (including
    cancellation) cancels that scope. On exit the group waits for all children
    and raises any collected non-cancellation exceptions as a
    :class:`BaseExceptionGroup`.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        # Non-cancellation exceptions raised by the group body or by child tasks
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Wait for every child task, then raise collected exceptions, if any."""
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The group body failed or was cancelled: cancel the children too
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        waited_for_tasks_to_finish = bool(self.cancel_scope._tasks)
        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        # Yield control to the event loop here to ensure that there is at least one
        # yield point within __aexit__() (trio does the same)
        if not waited_for_tasks_to_finish:
            await AsyncIOBackend.checkpoint()

        return ignore_exception

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """Create a child task running ``func(*args)`` under the group's cancel scope.

        :param task_status_future: when given (by :meth:`start`), the child gets a
            ``task_status`` keyword argument. The future is resolved by
            ``task_status.started()``, or fails if the child exits before that.
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func(*args)`` does not return a coroutine object
        """

        def task_done(_task: asyncio.Task) -> None:
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the earliest CancelledError in the chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    self.cancel_scope.cancel()
                else:
                    # The child failed before calling task_status.started()
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs: dict[str, Any] = {}
        if task_status_future:
            # Until started() is called, the task that called start() is the parent
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Start ``func(*args)`` as a child task without waiting for it."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """Start a child task and wait until it calls ``task_status.started()``.

        :return: the value passed to ``task_status.started()``
        :raises RuntimeError: if the child exits without calling
            ``task_status.started()``
        :raises BaseException: whatever the child raises before calling
            ``task_status.started()``
        """
        future: asyncio.Future = get_running_loop().create_future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        with CancelScope(shield=True):
            try:
                return await future
            except CancelledError:
                task.cancel()
                raise
768

769

770
#
771
# Threads
772
#
773

774
# Generic alias for a (result, exception) pair; either element may be None
_Retval_Queue_Type = Tuple[
    Optional[T_Retval],
    Optional[BaseException],
]
775

776

777
class WorkerThread(Thread):
    """
    Worker thread that runs synchronous callables on behalf of one event loop.

    Work items arrive on :attr:`queue` as ``(context, func, args, future)`` tuples.
    ``func(*args)`` is run inside ``context`` and its outcome is handed back to the
    event loop thread, where :meth:`_report_result` resolves ``future``. Putting
    ``None`` on the queue makes the thread exit.
    """

    # NOTE(review): not read inside this class; presumably consulted elsewhere to
    # retire workers that have been idle for longer than this
    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Resolve ``future`` with a work item's outcome (scheduled on the loop thread).

        Also marks this worker idle again unless it is stopping. A cancelled future
        is left untouched. ``StopIteration`` is wrapped in a :exc:`RuntimeError`
        because asyncio futures refuse it as an exception.

        :param future: the future to resolve
        :param result: the return value of the work item (used when ``exc`` is None)
        :param exc: the exception raised by the work item, if any
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                if isinstance(exc, StopIteration):
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Process work items until the ``None`` shutdown sentinel is received."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future = item
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc

                    if not self.loop.is_closed():
                        try:
                            self.loop.call_soon_threadsafe(
                                self._report_result, future, result, exception
                            )
                        except RuntimeError:
                            # The loop can get closed between the is_closed() check
                            # and this call. Treat that exactly like an already
                            # closed loop instead of letting the error kill this
                            # thread; any other RuntimeError still propagates.
                            if not self.loop.is_closed():
                                raise

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Ask the thread to exit and unregister it from the worker collections.

        :param f: ignored (presumably present so this can serve as a task done
            callback)
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # Not currently idle
            pass
847

848

849
# Thread pool bookkeeping stored in RunVars (presumably one value per running event
# loop): the deque of currently idle WorkerThreads and the set of all WorkerThreads.
_threadpool_idle_workers: RunVar[deque[WorkerThread]]
_threadpool_idle_workers = RunVar("_threadpool_idle_workers")
_threadpool_workers: RunVar[set[WorkerThread]]
_threadpool_workers = RunVar("_threadpool_workers")
853

854

855
class BlockingPortal(abc.BlockingPortal):
    """asyncio blocking portal bound to the event loop that was running at creation."""

    def __new__(cls) -> BlockingPortal:
        # Allocate directly, bypassing any __new__ logic of the base class
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """
        From a foreign thread, schedule ``self._call_func`` as a task in the portal's
        task group, running on the portal's event loop.
        """
        spawner = partial(self._task_group.start_soon, name=name)
        spawner_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(spawner, spawner_args, self._loop)
876

877

878
#
879
# Subprocesses
880
#
881

882

883
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Adapts an :class:`asyncio.StreamReader` to the byte receive stream interface."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the underlying reader.

        :raises EndOfStream: if the reader returned no data
        """
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        # Signal EOF on the underlying reader
        self._stream.feed_eof()
896

897

898
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Adapts an :class:`asyncio.StreamWriter` to the byte send stream interface."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write ``item`` and then wait for the writer's ``drain()``."""
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        # Close the writer; completion of the close is not awaited here
        self._stream.close()
908

909

910
@dataclass(eq=False)
class Process(abc.Process):
    """Process handle wrapping an :class:`asyncio.subprocess.Process`."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close each standard stream that exists, then wait for the process."""
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the process to exit and return its exit code."""
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        # None while the process has not been reaped yet
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
958

959

960
def _forcibly_shutdown_process_pool_on_exit(
10✔
961
    workers: set[Process], _task: object
962
) -> None:
963
    """
964
    Forcibly shuts down worker processes belonging to this event loop."""
965
    child_watcher: asyncio.AbstractChildWatcher | None = None
9✔
966
    if sys.version_info < (3, 12):
9✔
967
        try:
8✔
968
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
8✔
969
        except NotImplementedError:
2✔
970
            pass
2✔
971

972
    # Close as much as possible (w/o async/await) to avoid warnings
973
    for process in workers:
9✔
974
        if process.returncode is None:
9✔
975
            continue
9✔
976

UNCOV
977
        process._stdin._stream._transport.close()  # type: ignore[union-attr]
×
UNCOV
978
        process._stdout._stream._transport.close()  # type: ignore[union-attr]
×
UNCOV
979
        process._stderr._stream._transport.close()  # type: ignore[union-attr]
×
UNCOV
980
        process.kill()
×
UNCOV
981
        if child_watcher:
×
UNCOV
982
            child_watcher.remove_child_handler(process.pid)
×
983

984

985
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
10✔
986
    """
987
    Shuts down worker processes belonging to this event loop.
988

989
    NOTE: this only works when the event loop was started using asyncio.run() or
990
    anyio.run().
991

992
    """
993
    process: abc.Process
994
    try:
9✔
995
        await sleep(math.inf)
9✔
996
    except asyncio.CancelledError:
9✔
997
        for process in workers:
9✔
998
            if process.returncode is None:
9✔
999
                process.kill()
9✔
1000

1001
        for process in workers:
9✔
1002
            await process.aclose()
9✔
1003

1004

1005
#
1006
# Sockets and networking
1007
#
1008

1009

1010
class StreamProtocol(asyncio.Protocol):
    """
    asyncio protocol that buffers incoming data for :class:`SocketStream`.

    Received chunks are appended to :attr:`read_queue` and announced through
    :attr:`read_event`. :attr:`write_event` is set while writing may proceed. On
    ``pause_writing()`` it is replaced by a fresh, unset event. If the connection is
    lost with an error, that error is recorded in :attr:`exception` as the cause of
    a ``BrokenResourceError``.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event, self.write_event = asyncio.Event(), asyncio.Event()
        self.write_event.set()
        # A zero high-water mark makes the transport call pause_writing() as soon as
        # anything remains in its write buffer
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up anyone waiting to read or write
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.read_event.set()
        # A true return value stops the transport from closing itself on EOF
        return True

    def pause_writing(self) -> None:
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
1044

1045

1046
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    asyncio datagram protocol that buffers received packets for the UDP socket
    classes.

    Packets are kept in :attr:`read_queue` as ``(data, address)`` pairs and
    announced through :attr:`read_event`. :attr:`write_event` is cleared while the
    transport has paused writing.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded at an arbitrary 100 packets: once full, every append silently
        # discards the oldest queued packet
        self.read_queue = deque(maxlen=100)
        self.read_event, self.write_event = asyncio.Event(), asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up anyone waiting to receive or send
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        normalized = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, normalized))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1075

1076

1077
class SocketStream(abc.SocketStream):
    """Byte stream over an asyncio transport paired with a :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next buffered chunk, at most ``max_bytes`` long.

        :raises ClosedResourceError: if this stream was closed via :meth:`aclose`
        :raises BrokenResourceError: (the protocol's recorded exception) if the
            connection was lost with an error
        :raises EndOfStream: if no more data will arrive
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # Nothing announced yet and the transport is still open: read until the
            # protocol signals new data, then stop reading again
            if not (protocol.read_event.is_set() or self._transport.is_closing()):
                self._transport.resume_reading()
                await protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                if protocol.exception:
                    raise protocol.exception from None
                raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Oversized chunk: push the surplus back to the front of the queue
                protocol.read_queue.appendleft(chunk[max_bytes:])
                chunk = chunk[:max_bytes]

            # Once the queue is drained, clear the flag so that the next call blocks
            # until more data arrives
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait on the protocol's write event.

        :raises ClosedResourceError: if this stream was closed via :meth:`aclose`
        :raises BrokenResourceError: if the connection is broken
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise

                raise BrokenResourceError from exc

            # The event is unset while the transport has paused writing
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        try:
            self._transport.write_eof()
        except OSError:
            # Deliberately best-effort
            pass

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        try:
            self._transport.write_eof()
        except OSError:
            pass

        self._transport.close()
        # Yield to the event loop once before aborting the transport
        await sleep(0)
        self._transport.abort()
1159

1160

1161
class _RawSocketMixin:
10✔
1162
    _receive_future: asyncio.Future | None = None
10✔
1163
    _send_future: asyncio.Future | None = None
10✔
1164
    _closing = False
10✔
1165

1166
    def __init__(self, raw_socket: socket.socket):
10✔
1167
        self.__raw_socket = raw_socket
7✔
1168
        self._receive_guard = ResourceGuard("reading from")
7✔
1169
        self._send_guard = ResourceGuard("writing to")
7✔
1170

1171
    @property
10✔
1172
    def _raw_socket(self) -> socket.socket:
10✔
1173
        return self.__raw_socket
7✔
1174

1175
    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
10✔
1176
        def callback(f: object) -> None:
7✔
1177
            del self._receive_future
7✔
1178
            loop.remove_reader(self.__raw_socket)
7✔
1179

1180
        f = self._receive_future = asyncio.Future()
7✔
1181
        loop.add_reader(self.__raw_socket, f.set_result, None)
7✔
1182
        f.add_done_callback(callback)
7✔
1183
        return f
7✔
1184

1185
    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
10✔
1186
        def callback(f: object) -> None:
7✔
1187
            del self._send_future
7✔
1188
            loop.remove_writer(self.__raw_socket)
7✔
1189

1190
        f = self._send_future = asyncio.Future()
7✔
1191
        loop.add_writer(self.__raw_socket, f.set_result, None)
7✔
1192
        f.add_done_callback(callback)
7✔
1193
        return f
7✔
1194

1195
    async def aclose(self) -> None:
10✔
1196
        if not self._closing:
7✔
1197
            self._closing = True
7✔
1198
            if self.__raw_socket.fileno() != -1:
7✔
1199
                self.__raw_socket.close()
7✔
1200

1201
            if self._receive_future:
7✔
1202
                self._receive_future.set_result(None)
7✔
1203
            if self._send_future:
7✔
UNCOV
1204
                self._send_future.set_result(None)
×
1205

1206

1207
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """Stream over a connected UNIX domain socket, with file descriptor passing."""

    async def send_eof(self) -> None:
        """Shut down the writing side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes.

        :raises ClosedResourceError: if the stream was closed via :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        :raises EndOfStream: if the peer closed its sending side
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                if data:
                    return data

                raise EndOfStream

    async def send(self, item: bytes) -> None:
        """Send all of ``item``, waiting for writability whenever the socket blocks."""
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    # Partial sends are possible; keep going with the rest
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message of up to ``msglen`` bytes plus up to ``maxfds`` descriptors.

        :raises ValueError: if ``msglen`` or ``maxfds`` is out of range
        :raises EndOfStream: if neither data nor ancillary data was received
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` arrived
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

                if not (message or ancdata):
                    raise EndOfStream

                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore a trailing partial integer, if any
            usable = len(cmsg_data) - len(cmsg_data) % fds.itemsize
            fds.frombytes(cmsg_data[:usable])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send ``message`` together with the given file descriptors.

        Integers are sent as-is; :class:`IOBase` objects contribute their
        ``fileno()``. Objects of any other type are skipped.

        :raises ValueError: if ``message`` or ``fds`` is empty
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    return
1321

1322

1323
class TCPSocketListener(abc.SocketListener):
    """Listener accepting TCP connections via the event loop's ``sock_accept()``."""

    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _remove_accept_reader(self) -> None:
        """Drop the loop's reader for the listening socket, tolerating its absence."""
        # Workaround for https://bugs.python.org/issue41317
        try:
            self._loop.remove_reader(self._raw_socket)
        except (ValueError, NotImplementedError):
            pass

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next connection and wrap it in a :class:`SocketStream`.

        :raises ClosedResourceError: if the listener is (or gets) closed
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            # aclose() cancels this scope to interrupt a pending accept
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _ = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    self._remove_accept_reader()
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            self._remove_accept_reader()
            self._accept_scope.cancel()
            # Let the cancelled accept() run before the socket is closed
            await sleep(0)

        self._raw_socket.close()
1381

1382

1383
class UNIXSocketListener(abc.SocketListener):
    """Listener accepting UNIX domain socket connections via readiness callbacks."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next connection as a non-blocking :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if the listener was closed via :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        await AsyncIOBackend.checkpoint()
        listener = self.__raw_socket
        with self._accept_guard:
            while True:
                try:
                    conn, _ = listener.accept()
                    conn.setblocking(False)
                    return UNIXSocketStream(conn)
                except BlockingIOError:
                    # No pending connection: wait for the listener to become readable
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(listener, readable.set_result, None)
                    readable.add_done_callback(
                        lambda _: self._loop.remove_reader(listener)
                    )
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        self._closed = True
        self.__raw_socket.close()
1418

1419

1420
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport and protocol."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        # _closed only records a close initiated here, not one by the transport
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next ``(data, address)`` packet.

        :raises ClosedResourceError: if this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if no packet is available for any other reason
        """
        protocol = self._protocol
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Block only when nothing is buffered and the transport is still open
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """Send ``item`` (unpacked into the transport's ``sendto()``)."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1466

1467

1468
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket backed by an asyncio datagram transport and protocol."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        # _closed only records a close initiated here, not one by the transport
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next packet (the sender address is dropped).

        :raises ClosedResourceError: if this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if no packet is available for any other reason
        """
        protocol = self._protocol
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Block only when nothing is buffered and the transport is still open
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

            return packet[0]

    async def send(self, item: bytes) -> None:
        """Send ``item`` to the connected peer."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1516

1517

1518
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket driven through readiness callbacks."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Return the next ``recvfrom(65536)`` result.

        :raises ClosedResourceError: if the socket was closed via :meth:`aclose`
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """Send ``item`` (unpacked into ``sendto()``), waiting if the socket blocks."""
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
1552

1553

1554
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """UNIX datagram socket connected to a fixed peer, driven through a raw socket."""

    async def receive(self) -> bytes:
        """Receive the payload of a single datagram from the connected peer.

        Whenever the raw socket reports that the call would block, waits for it
        to become readable and tries again.

        :raises ClosedResourceError: if an OS error occurs while ``self._closing``
            is set
        :raises BrokenResourceError: on any other OS error
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as err:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from err

    async def send(self, item: bytes) -> None:
        """Send ``item`` as one datagram to the connected peer.

        Whenever the raw socket reports that the call would block, waits for it
        to become writable and tries again.

        :raises ClosedResourceError: if an OS error occurs while ``self._closing``
            is set
        :raises BrokenResourceError: on any other OS error
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as err:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from err
1588

1589

1590
# RunVar-held registries mapping a socket to the asyncio.Event that a pending
# wait_socket_readable() / wait_socket_writable() call is currently waiting on
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
1592

1593

1594
#
1595
# Synchronization
1596
#
1597

1598

1599
class Event(BaseEvent):
    """asyncio implementation of an anyio event, wrapping :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Create the instance directly, skipping any base-class __new__ logic
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the flag, waking every task blocked in :meth:`wait`."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the flag has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Block until the flag is set, then pass through a checkpoint."""
        woke = await self._event.wait()
        # asyncio.Event.wait() returns True once the event is set, so this always
        # ends with a checkpoint - even when the event was already set on entry
        if woke:
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics; the count is the number of tasks blocked in wait()."""
        waiters = self._event._waiters  # type: ignore[attr-defined]
        return EventStatistics(len(waiters))
1618

1619

1620
class CapacityLimiter(BaseCapacityLimiter):
    """asyncio implementation of an anyio capacity limiter.

    Token holders are tracked in ``_borrowers``. Tasks waiting for a token are
    queued in FIFO order in ``_wait_queue`` (borrower -> :class:`asyncio.Event`).
    A waiter's entry is removed from the queue at the moment its event is set,
    i.e. when a token is handed to it.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """Change the number of tokens, waking waiters if capacity grew.

        :raises TypeError: if ``value`` is neither an int nor infinite
        :raises ValueError: if ``value`` is less than 1
        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        # Only an increase frees tokens; comparing first also avoids inf - inf (nan)
        if value > self._total_tokens:
            waiters_to_notify = value - self._total_tokens
        else:
            waiters_to_notify = 0

        self._total_tokens = value

        # Hand the new tokens to the longest-waiting tasks. Their entries are
        # popped so that a woken waiter no longer makes acquire_nowait() raise
        # WouldBlock, and so release() cannot "wake" an already-set event.
        while self._wait_queue and waiters_to_notify > 0:
            self._wait_queue.popitem(last=False)[1].set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """Take a token for ``borrower`` without waiting.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if tasks are queued or no token is free
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Take a token for ``borrower``, waiting in FIFO order if none is free."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A token was already handed to us, but we were interrupted
                    # before taking it: pass it on so the next waiter isn't stranded
                    self._notify_next_waiter()
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Release the token of the borrower it was acquired for, which is
                # not necessarily the current task
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """Return ``borrower``'s token and wake the next waiter if capacity allows.

        :raises RuntimeError: if ``borrower`` does not hold a token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        self._notify_next_waiter()

    def _notify_next_waiter(self) -> None:
        """Wake (and dequeue) the longest-waiting task if a token is free."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1737

1738

1739
# RunVar-held default limiter for worker threads; populated lazily by
# AsyncIOBackend.current_default_thread_limiter()
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
1740

1741

1742
#
1743
# Operating system signals
1744
#
1745

1746

1747
class _SignalReceiver:
10✔
1748
    def __init__(self, signals: tuple[Signals, ...]):
10✔
1749
        self._signals = signals
8✔
1750
        self._loop = get_running_loop()
8✔
1751
        self._signal_queue: deque[Signals] = deque()
8✔
1752
        self._future: asyncio.Future = asyncio.Future()
8✔
1753
        self._handled_signals: set[Signals] = set()
8✔
1754

1755
    def _deliver(self, signum: Signals) -> None:
10✔
1756
        self._signal_queue.append(signum)
8✔
1757
        if not self._future.done():
8✔
1758
            self._future.set_result(None)
8✔
1759

1760
    def __enter__(self) -> _SignalReceiver:
10✔
1761
        for sig in set(self._signals):
8✔
1762
            self._loop.add_signal_handler(sig, self._deliver, sig)
8✔
1763
            self._handled_signals.add(sig)
8✔
1764

1765
        return self
8✔
1766

1767
    def __exit__(
10✔
1768
        self,
1769
        exc_type: type[BaseException] | None,
1770
        exc_val: BaseException | None,
1771
        exc_tb: TracebackType | None,
1772
    ) -> bool | None:
1773
        for sig in self._handled_signals:
8✔
1774
            self._loop.remove_signal_handler(sig)
8✔
1775
        return None
8✔
1776

1777
    def __aiter__(self) -> _SignalReceiver:
10✔
1778
        return self
8✔
1779

1780
    async def __anext__(self) -> Signals:
10✔
1781
        await AsyncIOBackend.checkpoint()
8✔
1782
        if not self._signal_queue:
8✔
UNCOV
1783
            self._future = asyncio.Future()
×
UNCOV
1784
            await self._future
×
1785

1786
        return self._signal_queue.popleft()
8✔
1787

1788

1789
#
1790
# Testing and debugging
1791
#
1792

1793

1794
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """Build a :class:`TaskInfo` snapshot describing ``task``.

    The parent id comes from the task's entry in ``_task_states``; it is
    ``None`` for tasks that have no entry there.
    """
    state = _task_states.get(task)
    parent_id = None if state is None else state.parent_id
    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
1802

1803

1804
class TestRunner(abc.TestRunner):
    """Runs anyio test functions and fixtures on an asyncio event loop.

    Every test and fixture coroutine is executed inside one long-lived "runner
    task". That task is created lazily and receives ``(coroutine, future)`` work
    items over a memory object stream.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop_factory takes precedence over use_uvloop
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        loop = self.get_loop()
        loop.set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances from the loop; defer anything else."""
        exc = context.get("exception")
        if not isinstance(exc, Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(exc)

    def _raise_async_exceptions(self) -> None:
        """Re-raise (and clear) exceptions collected from asynchronous callbacks."""
        if not self._exceptions:
            return

        exceptions, self._exceptions = self._exceptions, []
        if len(exceptions) == 1:
            raise exceptions[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", exceptions
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """Body of the runner task: await each received coroutine, report via its future."""
        with receive_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    # The caller may have abandoned the future; only report if not
                    if future.cancelled():
                        continue
                    future.set_exception(exc)
                else:
                    if not future.cancelled():
                        future.set_result(retval)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` inside the runner task and return its result."""
        if self._runner_task is None:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            runner_coro = self._run_tests_and_fixtures(receive_stream)
            self._runner_task = self.get_loop().create_task(runner_coro)

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """Drive an async generator fixture: set up, yield its value, tear down.

        :raises RuntimeError: if the generator yields more than once
        """
        asyncgen = fixture_func(**kwargs)
        setup = self._call_in_runner_task(asyncgen.asend, None)
        fixturevalue: T_Retval = self.get_loop().run_until_complete(setup)
        self._raise_async_exceptions()

        yield fixturevalue

        teardown = self._call_in_runner_task(asyncgen.asend, None)
        try:
            self.get_loop().run_until_complete(teardown)
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        call = self._call_in_runner_task(fixture_func, **kwargs)
        retval = self.get_loop().run_until_complete(call)
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a coroutine test; its ``Exception`` is raised together with callback errors."""
        call = self._call_in_runner_task(test_func, **kwargs)
        try:
            self.get_loop().run_until_complete(call)
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
1936

1937

1938
class AsyncIOBackend(AsyncBackend):
    """anyio backend built on the standard library :mod:`asyncio` event loop."""

    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in a new event loop and return its result.

        Recognized ``options``:

        * ``debug``: event loop debug mode (default ``False``)
        * ``loop_factory``: callable that creates the event loop
        * ``use_uvloop``: if true and no ``loop_factory`` is given, use
          ``uvloop.new_event_loop``
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args, **kwargs)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        # These options were previously read and then discarded; honor them now
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield to the event loop only if an enclosing, unshielded scope was cancelled."""
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        # Walk outward until a shielded scope is reached. On a cancelled scope,
        # keep yielding (without moving outward) until sleep(0) raises.
        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """Return the nearest deadline of the enclosing scopes, up to the first shield.

        Returns ``-inf`` if one of those scopes was already cancelled, and ``inf``
        if the current task has no scope information.
        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """Run ``func(*args)`` in a pooled worker thread and await its result."""
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with (limiter or cls.current_default_thread_limiter()):
            with CancelScope(shield=not cancellable):
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, future))
                return await future

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """Run ``func(*args)`` in the loop given by ``token``, blocking this thread."""
        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, func(*args), loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """Call ``func(*args)`` in the loop thread given by ``token``, blocking this one."""

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """Connect a non-blocking UNIX stream socket to ``path``.

        The socket is closed if connecting fails or the call is cancelled.
        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        # The outer try also covers exceptions (e.g. cancellation) raised while
        # waiting inside the BlockingIOError handler; a sibling except clause of
        # the inner try would not catch those and the socket would leak.
        try:
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Wait until the socket is writable, then retry connect()
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """Wrap ``raw_socket``, first connecting it to ``remote_path`` if one is given.

        When connecting, the socket is closed if that fails or is cancelled.
        """
        await cls.checkpoint()
        loop = get_running_loop()

        if not remote_path:
            return UNIXDatagramSocket(raw_socket)

        # The outer try also covers cancellation while awaiting in the
        # BlockingIOError handler (see connect_unix)
        try:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        except BaseException:
            raw_socket.close()
            raise

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """Wait until ``sock`` is readable.

        :raises BusyResourceError: if another task is already waiting on ``sock``
        :raises ClosedResourceError: if the registration was removed by someone
            else while waiting
        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # If the entry is gone, someone else removed it while we waited
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """Wait until ``sock`` is writable.

        :raises BusyResourceError: if another task is already waiting on ``sock``
        :raises ClosedResourceError: if the registration was removed by someone
            else while waiting
        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register and unregister with the same object, as wait_socket_readable does
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Poll until every other task is waiting on an unfinished future."""
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        return TestRunner(**options)
2403

2404

2405
# Module-level handle exposing this module's backend implementation class
backend_class = AsyncIOBackend
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc