• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 10653647235

01 Sep 2024 11:34AM UTC coverage: 91.671% (-0.1%) from 91.768%
10653647235

Pull #761

github

web-flow
Merge 787a4544d into bc962eff7
Pull Request #761: Delegated the implementations of Lock and Semaphore to the async backend class

229 of 250 new or added lines in 4 files covered. (91.6%)

2 existing lines in 2 files now uncovered.

4744 of 5175 relevant lines covered (91.67%)

9.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.72
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
11✔
2

3
import array
11✔
4
import asyncio
11✔
5
import concurrent.futures
11✔
6
import math
11✔
7
import socket
11✔
8
import sys
11✔
9
import threading
11✔
10
import weakref
11✔
11
from asyncio import (
11✔
12
    AbstractEventLoop,
13
    CancelledError,
14
    all_tasks,
15
    create_task,
16
    current_task,
17
    get_running_loop,
18
    sleep,
19
)
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
11✔
21
from collections import OrderedDict, deque
11✔
22
from collections.abc import AsyncIterator, Iterable
11✔
23
from concurrent.futures import Future
11✔
24
from contextlib import suppress
11✔
25
from contextvars import Context, copy_context
11✔
26
from dataclasses import dataclass
11✔
27
from functools import partial, wraps
11✔
28
from inspect import (
11✔
29
    CORO_RUNNING,
30
    CORO_SUSPENDED,
31
    getcoroutinestate,
32
    iscoroutine,
33
)
34
from io import IOBase
11✔
35
from os import PathLike
11✔
36
from queue import Queue
11✔
37
from signal import Signals
11✔
38
from socket import AddressFamily, SocketKind
11✔
39
from threading import Thread
11✔
40
from types import TracebackType
11✔
41
from typing import (
11✔
42
    IO,
43
    Any,
44
    AsyncGenerator,
45
    Awaitable,
46
    Callable,
47
    Collection,
48
    ContextManager,
49
    Coroutine,
50
    Mapping,
51
    Optional,
52
    Sequence,
53
    Tuple,
54
    TypeVar,
55
    cast,
56
)
57
from weakref import WeakKeyDictionary
11✔
58

59
import sniffio
11✔
60

61
from .. import (
11✔
62
    CapacityLimiterStatistics,
63
    EventStatistics,
64
    LockStatistics,
65
    TaskInfo,
66
    abc,
67
)
68
from .._core._eventloop import claim_worker_thread, threadlocals
11✔
69
from .._core._exceptions import (
11✔
70
    BrokenResourceError,
71
    BusyResourceError,
72
    ClosedResourceError,
73
    EndOfStream,
74
    WouldBlock,
75
    iterate_exceptions,
76
)
77
from .._core._sockets import convert_ipv6_sockaddr
11✔
78
from .._core._streams import create_memory_object_stream
11✔
79
from .._core._synchronization import (
11✔
80
    CapacityLimiter as BaseCapacityLimiter,
81
)
82
from .._core._synchronization import Event as BaseEvent
11✔
83
from .._core._synchronization import Lock as BaseLock
11✔
84
from .._core._synchronization import (
11✔
85
    ResourceGuard,
86
    SemaphoreStatistics,
87
)
88
from .._core._synchronization import Semaphore as BaseSemaphore
11✔
89
from .._core._tasks import CancelScope as BaseCancelScope
11✔
90
from ..abc import (
11✔
91
    AsyncBackend,
92
    IPSockAddrType,
93
    SocketListener,
94
    UDPPacketType,
95
    UNIXDatagramPacketType,
96
)
97
from ..lowlevel import RunVar
11✔
98
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11✔
99

100
if sys.version_info >= (3, 10):
11✔
101
    from typing import ParamSpec
7✔
102
else:
103
    from typing_extensions import ParamSpec
4✔
104

105
if sys.version_info >= (3, 11):
11✔
106
    from asyncio import Runner
5✔
107
    from typing import TypeVarTuple, Unpack
5✔
108
else:
109
    import contextvars
6✔
110
    import enum
6✔
111
    import signal
6✔
112
    from asyncio import coroutines, events, exceptions, tasks
6✔
113

114
    from exceptiongroup import BaseExceptionGroup
6✔
115
    from typing_extensions import TypeVarTuple, Unpack
6✔
116

117
    class _State(enum.Enum):
6✔
118
        CREATED = "created"
6✔
119
        INITIALIZED = "initialized"
6✔
120
        CLOSED = "closed"
6✔
121

122
    class Runner:
        """Backport of ``asyncio.Runner`` for interpreters older than Python 3.11.

        The runner owns a single event loop. ``_lazy_init()`` creates it on first
        use, either through *loop_factory* or via ``events.new_event_loop()``. In
        the latter case the loop is also installed with ``events.set_event_loop()``.
        ``close()`` tears the loop down and is also called when leaving a ``with``
        block.
        """

        # Copied from CPython 3.11
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            # Context that run() creates tasks in when none is passed explicitly
            self._context: Context | None = None
            # Number of SIGINTs seen during the current run() call
            self._interrupt_count = 0
            # Whether _lazy_init() installed the loop via events.set_event_loop()
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop.

            Cancel the remaining tasks, then shut down async generators and the
            default executor, then close the loop. This is a no-op unless the
            runner is currently initialized.
            """
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # loop.shutdown_default_executor() only exists on Python 3.9+
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop (initializing the runner if needed)."""
            self._lazy_init()
            return self._loop

        def run(
            self,
            coro: Coroutine[Any, Any, T_Retval],
            *,
            context: Context | None = None,
        ) -> T_Retval:
            """Run a coroutine inside the embedded event loop.

            :param coro: coroutine to wrap in a task and run to completion
            :param context: context to create the task in; defaults to the copy
                of the current context taken when the runner was initialized
            :return: the coroutine's return value
            :raises ValueError: if *coro* is not a coroutine
            :raises RuntimeError: if an event loop is already running in this
                thread, or if the runner has been closed
            :raises KeyboardInterrupt: if a SIGINT cancelled the task and the
                cancellation count drops back to zero (see note below)
            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            # loop.create_task(context=...) is 3.11+, so enter the context manually
            task = context.run(self._loop.create_task, coro)

            # Only take over SIGINT in the main thread, and only if nobody else has
            # replaced the default handler
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # NOTE(review): Task.uncancel() was added in Python 3.11. On the
                    # interpreters that use this backport, `uncancel` is presumably
                    # always None, so a Ctrl+C surfaces as CancelledError rather than
                    # KeyboardInterrupt. Confirm whether that is intended.
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """Create the event loop and context on first use.

            :raises RuntimeError: if the runner has already been closed
            """
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(
            self, signum: int, frame: object, main_task: asyncio.Task
        ) -> None:
            """Handle SIGINT by cancelling the main task on the first interrupt.

            Raise KeyboardInterrupt directly on a repeated interrupt, or if the
            main task has already finished.
            """
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
249

250
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
6✔
251
        to_cancel = tasks.all_tasks(loop)
6✔
252
        if not to_cancel:
6✔
253
            return
6✔
254

255
        for task in to_cancel:
6✔
256
            task.cancel()
6✔
257

258
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
6✔
259

260
        for task in to_cancel:
6✔
261
            if task.cancelled():
6✔
262
                continue
6✔
263
            if task.exception() is not None:
5✔
264
                loop.call_exception_handler(
×
265
                    {
266
                        "message": "unhandled exception during asyncio.run() shutdown",
267
                        "exception": task.exception(),
268
                        "task": task,
269
                    }
270
                )
271

272
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
6✔
273
        """Schedule the shutdown of the default executor."""
274

275
        def _do_shutdown(future: asyncio.futures.Future) -> None:
3✔
276
            try:
3✔
277
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
3✔
278
                loop.call_soon_threadsafe(future.set_result, None)
3✔
279
            except Exception as ex:
×
280
                loop.call_soon_threadsafe(future.set_exception, ex)
×
281

282
        loop._executor_shutdown_called = True
3✔
283
        if loop._default_executor is None:
3✔
284
            return
3✔
285
        future = loop.create_future()
3✔
286
        thread = threading.Thread(target=_do_shutdown, args=(future,))
3✔
287
        thread.start()
3✔
288
        try:
3✔
289
            await future
3✔
290
        finally:
291
            thread.join()
3✔
292

293

294
# Return type of run callables/coroutines (e.g. Runner.run, worker thread results)
T_Retval = TypeVar("T_Retval")
# Contravariant type variable; used for the value passed to TaskStatus.started()
T_contra = TypeVar("T_contra", contravariant=True)
# Variadic positional-argument types, pairing ``func`` with its ``*args``
PosArgsT = TypeVarTuple("PosArgsT")
# Parameter specification type variable for typing wrapped callables
P = ParamSpec("P")

# Cache for the root task used by find_root_task(). NOTE(review): RunVar (from
# anyio.lowlevel) is presumably scoped to the current event loop run — confirm.
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
300

301

302
def find_root_task() -> asyncio.Task:
    """Return the task treated as the root of the current AnyIO run.

    Candidates are tried in this order:

    1. the task cached in ``_root_task``, if it has not finished;
    2. an unfinished task whose done-callbacks include the loop's
       ``run_until_complete()`` callback, or any callback from ``uvloop.loop``.
       This task is cached in ``_root_task`` for later calls;
    3. the host task of the outermost cancel scope of the current task;
    4. the current task itself.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        if any(
            callback is _run_until_complete_cb
            or getattr(callback, "__module__", None) == "uvloop.loop"
            for callback, _ in candidate._callbacks
        ):
            _root_task.set(candidate)
            return candidate

    # Fall back to the topmost task in the AnyIO task tree, if possible
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if state:
        outermost = state.cancel_scope
        while outermost and outermost._parent_scope is not None:
            outermost = outermost._parent_scope

        if outermost is not None:
            return cast(asyncio.Task, outermost._host_task)

    return current
331

332

333
def get_callable_name(func: Callable) -> str:
    """Return ``"<module>.<qualname>"`` for *func*, omitting parts that are missing or empty."""
    parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(filter(None, parts))
337

338

339
#
340
# Event loop
341
#
342

343
# Storage keyed by event loop. Weak keys let an entry disappear once its loop is
# garbage collected. NOTE(review): presumably the backing store for RunVar values;
# its users are outside this view — confirm.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
344

345

346
def _task_started(task: asyncio.Task) -> bool:
11✔
347
    """Return ``True`` if the task has been started and has not finished."""
348
    try:
11✔
349
        return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
11✔
350
    except AttributeError:
×
351
        # task coro is async_genenerator_asend https://bugs.python.org/issue37771
352
        raise Exception(f"Cannot determine if task {task} has started or not") from None
×
353

354

355
#
356
# Timeouts and cancellation
357
#
358

359

360
class CancelScope(BaseCancelScope):
    """asyncio implementation of the AnyIO cancel scope.

    A scope is entered and exited as a synchronous context manager within one
    task, the *host task*. Cancelling the scope calls ``Task.cancel()`` on the
    tasks it directly contains. It also does so for tasks in nested child scopes
    that are neither shielded nor already cancelled. The delivery reschedules
    itself with ``call_soon()`` for as long as some task still needed
    cancelling.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Allocate directly with object.__new__, bypassing BaseCancelScope.__new__.
        # NOTE(review): presumably the base class dispatches to the current
        # backend there — confirm.
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        # Enclosing scope in the same host task (linked in __enter__)
        self._parent_scope: CancelScope | None = None
        # Scopes nested directly inside this one
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Timer that calls _timeout() at the deadline
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Scheduled retry of _deliver_cancellation(); None when none is pending
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks whose innermost cancel scope is this one
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls issued with this scope as the origin
        self._cancel_calls: int = 0
        # Host task's cancelling() count on entry (only set on Python 3.11+)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        """Attach this scope to the current task as its innermost cancel scope.

        :raises RuntimeError: if the scope is already active
        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # The task has no TaskState yet: this becomes its outermost scope
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest under the task's current scope and take the host task over
            # from the parent's set of directly contained tasks
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.remove(host_task)

        # Cancel now if the deadline already passed, otherwise arm the timer
        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            # Baseline for _uncancel(): cancellations pending before this scope
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Detach this scope from its host task.

        If the scope was cancelled and an exception is propagating, the method
        returns ``cancelled_caught``. That value is true (suppressing the
        exception) when ``_uncancel()`` attributes one of the propagating
        CancelledErrors to this scope. Otherwise ``None`` is returned.

        :raises RuntimeError: if the scope is not active, is exited from a
            different task, or is not the task's current cancel scope
        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Hand the host task back to the parent scope
        self._tasks.remove(self._host_task)
        if self._parent_scope is not None:
            self._parent_scope._child_scopes.remove(self)
            self._parent_scope._tasks.add(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the closest directly cancelled parent
        # scope if this one was shielded
        self._restart_cancellation_in_parent()

        if self._cancel_called and exc_val is not None:
            # NOTE(review): iterate_exceptions presumably flattens exception
            # groups into their leaf exceptions — confirm
            for exc in iterate_exceptions(exc_val):
                if isinstance(exc, CancelledError):
                    self._cancelled_caught = self._uncancel(exc)
                    if self._cancelled_caught:
                        break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """Undo this scope's cancellations and tell whether *cancelled_exc* is ours.

        The result depends on the Python version and on the host task:

        * On Python < 3.9, or with no host task, the exception is always treated
          as caught.
        * Otherwise, if a cancelling() baseline was recorded (Python 3.11+),
          ``uncancel()`` is called once per cancel this scope issued. The method
          returns ``True`` as soon as the count drops to the baseline.
        * Failing that, the result is whether the exception carries this scope's
          cancel message.
        """
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        self._cancel_calls = 0
        # Message format matches the one used in _deliver_cancellation()
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        """Cancel if the deadline has passed, otherwise re-arm the deadline timer."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks asyncio has already flagged for cancellation
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            should_retry = True
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                # Leave alone a task whose awaited future has already completed
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    origin._cancel_calls += 1
                    if sys.version_info >= (3, 9):
                        # The message lets _uncancel() recognize this scope's cancel
                        task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    else:
                        task.cancel()

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if no delivery retry is already scheduled
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def _parent_cancelled(self) -> bool:
        """Return whether an enclosing scope has been cancelled.

        The search stops at, and does not include, the nearest shielded
        ancestor.
        """
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """Cancel this scope.

        Only the first call has an effect. It disarms the deadline timer. If
        the scope has been entered, it also starts delivering the cancellation.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        # Replace any armed timer; re-arm only while active and not yet cancelled
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Unshielding exposes this scope to cancellation from its parents again
            if not value:
                self._restart_cancellation_in_parent()
601

602

603
#
604
# Task states
605
#
606

607

608
class TaskState:
    """
    Per-task bookkeeping kept outside the Task object, because the Task
    implementation gives no guarantee that extra attributes can be attached.

    ``parent_id`` holds the ``id()`` of the parent task (or ``None``).
    ``cancel_scope`` holds the task's innermost active cancel scope.
    """

    __slots__ = ("parent_id", "cancel_scope", "__weakref__")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        self.cancel_scope = cancel_scope
        self.parent_id = parent_id
619

620

621
# Auxiliary state for each task known to this backend. Entries are added by
# CancelScope.__enter__ and TaskGroup._spawn and removed in the task_done callback.
# Weak keys drop an entry once its task object is garbage collected.
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
622

623

624
#
625
# Task groups
626
#
627

628

629
class _AsyncioTaskStatus(abc.TaskStatus):
    """Task status object handed to children launched through ``TaskGroup.start()``.

    :param future: future that receives the value passed to :meth:`started`
    :param parent_id: ``id()`` of the task to record as the child's parent once
        it has started
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._future = future
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """Report the start value and re-parent the calling task.

        :raises RuntimeError: if called a second time (a future that was
            cancelled, e.g. because the waiting task was cancelled, is ignored)
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            was_cancelled = self._future.cancelled()
            if not was_cancelled:
                raise RuntimeError(
                    "called 'started' twice on the same task status"
                ) from None

        me = cast(asyncio.Task, current_task())
        _task_states[me].parent_id = self._parent_id
645

646

647
class TaskGroup(abc.TaskGroup):
    """asyncio implementation of the AnyIO task group.

    Child tasks share ``cancel_scope`` and are tracked in ``_tasks``. On exit the
    group waits for all children. Collected non-cancellation exceptions are then
    raised together as a ``BaseExceptionGroup``.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # True between __aenter__ and the end of __aexit__; gates _spawn()
        self._active = False
        self._exceptions: list[BaseException] = []
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # The cancel scope is exited first; the children stay in its task set, so a
        # later cancel() still reaches them
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        while self._tasks:
            try:
                await asyncio.wait(self._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """Create a task running ``func(*args)`` as a member of this group.

        If *task_status_future* is given, ``func`` also receives a
        ``task_status`` keyword argument, and the future is resolved by
        ``task_status.started()`` or failed if the child exits first.

        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func(*args)`` does not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            # Done callback: unregister the child, then route its outcome
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            # `task` is the enclosing-scope variable assigned below; it refers to
            # the same object as `_task`
            self._tasks.remove(task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the earliest CancelledError in the implicit context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._parent_cancelled():
                        self.cancel_scope.cancel()
                else:
                    # Failed before calling started(): surface the error to start()
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until started() is called the caller of start() is the parent; after
            # that, the parent becomes the task group's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """Spawn *func* and wait until it calls ``task_status.started()``.

        :return: the value passed to ``task_status.started()``
        :raises RuntimeError: if the child exits without calling ``started()``
        (if the child raised before calling it, that exception is raised
        instead)
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group gets cancelled. NOTE(review): this wait is not
        # itself shielded. It presumably relies on CancelScope._deliver_cancellation()
        # leaving alone a task whose awaited future is already done — confirm.
        # If this task is cancelled while waiting, cancel the child and wait for it to
        # exit (shielded, so the same cancellation cannot interrupt the wait) before
        # re-raising.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
802

803

804
#
805
# Threads
806
#
807

808
# (result, exception) pair; presumably exactly one side is non-None — confirm at
# the producers, which are outside this view
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
809

810

811
class WorkerThread(Thread):
    """Reusable thread that runs synchronous callables for the event loop.

    Jobs arrive through ``self.queue`` as ``(context, func, args, future,
    cancel_scope)`` tuples; a ``None`` item tells the thread to exit. Outcomes are
    handed back to the loop thread via ``call_soon_threadsafe``.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Holds at most two pending items (a job and/or the None stop sentinel)
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Complete ``future`` with a job's outcome (runs in the event loop thread)."""
        # This thread is idle again; return it to the pool unless it is stopping
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return  # nobody is waiting for the outcome any more

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # Presumably because asyncio futures refuse StopIteration; wrap it
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def _run_job(
        self,
        context: Context,
        func: Callable,
        args: tuple,
        future: asyncio.Future,
        cancel_scope: CancelScope,
    ) -> None:
        """Execute one job in ``context`` and schedule reporting of its outcome."""
        outcome: Any = None
        error: BaseException | None = None
        # Make the job's cancel scope visible through threadlocals while it runs
        threadlocals.current_cancel_scope = cancel_scope
        try:
            outcome = context.run(func, *args)
        except BaseException as exc:
            error = exc
        finally:
            del threadlocals.current_cancel_scope

        # If the loop has been closed there is no one to deliver the result to
        if not self.loop.is_closed():
            self.loop.call_soon_threadsafe(
                self._report_result, future, outcome, error
            )

    def run(self) -> None:
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    return  # stop() was called

                context, func, args, future, cancel_scope = item
                # Skip jobs whose caller has already given up on them
                if not future.cancelled():
                    self._run_job(context, func, args, future, cancel_scope)

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Tell the thread to exit and remove it from the pool bookkeeping.

        ``f`` is ignored (presumably so this can double as a task done-callback).
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        with suppress(ValueError):
            # It may not be idle right now, in which case it is not in the deque
            self.idle_workers.remove(self)
884

885

886
# RunVar-scoped worker thread pool state (RunVar presumably scopes these to the
# current event loop run).
# Idle worker threads; presumably the deque handed to WorkerThread as
# ``idle_workers`` (appended to by _report_result, removed from by stop()).
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# Presumably all live worker threads; WorkerThread.stop() discards itself from
# the set it was given as ``workers``.
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
890

891

892
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal implementation for the asyncio backend."""

    def __new__(cls) -> BlockingPortal:
        # Allocate directly with object.__new__ rather than any base-class __new__
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        # The loop that owns this portal; other threads schedule work onto it
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """Start ``self._call_func(func, args, kwargs, future)`` in the portal's
        task group by scheduling ``start_soon`` onto the portal's event loop.
        """
        start_in_group = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_in_group, call_args, self._loop)
913

914

915
#
916
# Subprocesses
917
#
918

919

920
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Adapts an :class:`asyncio.StreamReader` to AnyIO's byte receive stream."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            # An empty read signals end of stream
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        # Mark the reader as finished, then yield control to the event loop
        self._stream.feed_eof()
        await AsyncIOBackend.checkpoint()
934

935

936
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Adapts an :class:`asyncio.StreamWriter` to AnyIO's byte send stream."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        writer = self._stream
        writer.write(item)
        # Flow control: wait for the writer to drain its buffer
        await writer.drain()

    async def aclose(self) -> None:
        # Close the underlying writer, then yield control to the event loop
        self._stream.close()
        await AsyncIOBackend.checkpoint()
947

948

949
@dataclass(eq=False)
class Process(abc.Process):
    """Wraps an :class:`asyncio.subprocess.Process` and its optional stdio streams."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close the stdio streams and wait for the process to exit.

        If the wait is interrupted (by any exception, including cancellation),
        the process is killed and reaped under a shield before re-raising.
        """
        with CancelScope(shield=True):
            for stream in (self._stdin, self._stdout, self._stderr):
                if stream:
                    await stream.aclose()

        try:
            await self.wait()
        except BaseException:
            self.kill()
            with CancelScope(shield=True):
                await self.wait()

            raise

    async def wait(self) -> int:
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        # None while the process is still running
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
1005

1006

1007
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down worker processes belonging to this event loop.

    This runs synchronously, so it cannot await anything. For every worker that
    is still running, it closes the pipe transports that exist and kills the
    process. Workers that have already exited are left alone.

    :param workers: the worker processes to shut down
    :param _task: ignored (presumably the finished task when this is used as a
        done callback)
    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        if process.returncode is not None:
            # Already exited: there is nothing left to kill
            continue

        # Not every worker necessarily has all three pipes; skip missing ones
        for stream in (process._stdin, process._stdout, process._stderr):
            if stream is not None:
                stream._stream._transport.close()  # type: ignore[union-attr]

        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
1030

1031

1032
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shut down worker processes belonging to this event loop.

    Sleeps until cancelled. On cancellation it first kills every worker that is
    still running, then closes all workers one by one. The cancellation itself
    is not re-raised.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # filter() is lazy, so each returncode is checked just before its kill()
        for proc in filter(lambda p: p.returncode is None, workers):
            proc.kill()

        for proc in workers:
            await proc.aclose()
1050

1051

1052
#
1053
# Sockets and networking
1054
#
1055

1056

1057
class StreamProtocol(asyncio.Protocol):
    """Protocol that buffers incoming data for stream sockets.

    Readers wait on ``read_event``, which is set when data, EOF or connection loss
    arrives. Writers wait on ``write_event``, which is unset while the transport
    has paused writing.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()
        # High-water mark of 0: flow control kicks in as soon as data is buffered
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up anyone blocked on reading or writing
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        # ProactorEventloop sometimes sends bytearray instead of bytes
        as_bytes = bytes(data)
        self.read_queue.append(as_bytes)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        # True keeps the transport open (we may still want to write)
        return True

    def pause_writing(self) -> None:
        # Install a fresh, unset event; senders block on it until resume_writing()
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
1094

1095

1096
class DatagramProtocol(asyncio.DatagramProtocol):
    """Protocol that buffers received datagrams for the UDP socket classes.

    The receive buffer is a deque bounded at 100 entries, so the oldest
    datagrams are discarded when it overflows.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Release any reader or writer that is currently waiting
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        normalized = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, normalized))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1125

1126

1127
class SocketStream(abc.SocketStream):
    """Byte stream over an asyncio transport driven by :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        with self._receive_guard:
            protocol = self._protocol
            must_wait = not (
                protocol.read_event.is_set()
                or self._transport.is_closing()
                or protocol.is_at_eof
            )
            if must_wait:
                # Only read from the transport while someone is waiting for data
                self._transport.resume_reading()
                await protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                # Nothing buffered: report why
                if self._closed:
                    raise ClosedResourceError from None
                elif protocol.exception:
                    raise protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Return only what was asked for; push the remainder back
                protocol.read_queue.appendleft(chunk[max_bytes:])
                chunk = chunk[:max_bytes]

            if not protocol.read_queue:
                # Buffer drained, so the next receive() must wait for new data
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                # A RuntimeError from a closing transport means the peer is gone
                if not self._transport.is_closing():
                    raise

                raise BrokenResourceError from exc

            # Flow control: wait until the protocol allows writing again
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        with suppress(OSError):
            self._transport.write_eof()

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        with suppress(OSError):
            # Half-close first so the peer sees EOF
            self._transport.write_eof()

        self._transport.close()
        await sleep(0)
        self._transport.abort()
1210

1211

1212
class _RawSocketMixin:
    """Shared plumbing for streams built directly on a non-blocking socket.

    ``_wait_until_readable`` / ``_wait_until_writable`` return futures resolved by
    event loop reader/writer callbacks. :meth:`aclose` resolves any that are still
    pending, so blocked operations resume and retry on the now-closed socket.
    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes once the socket becomes readable."""

        def callback(f: object) -> None:
            # Deleting the instance attribute exposes the class default (None)
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes once the socket becomes writable."""

        def callback(f: object) -> None:
            # Deleting the instance attribute exposes the class default (None)
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # Wake up pending waiters. A future can already be done (e.g.
            # cancelled) while its cleanup callback has not run yet; calling
            # set_result() on it again would raise InvalidStateError.
            if self._receive_future and not self._receive_future.done():
                self._receive_future.set_result(None)
            if self._send_future and not self._send_future.done():
                self._send_future.set_result(None)
1256

1257

1258
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX domain stream socket, including file descriptor passing."""

    async def send_eof(self) -> None:
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if data:
                    return data

                # recv() returning nothing means the peer closed its end
                raise EndOfStream

    async def send(self, item: bytes) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            # send() may be partial, so keep going until everything is out
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream

                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Keep only whole ints; ignore any trailing partial item
            whole_len = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
            fds.frombytes(cmsg_data[:whole_len])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        # Accept raw descriptors or file objects; anything else is ignored
        filenos: list[int] = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # The ignore can be removed after mypy picks up
                    # https://github.com/python/typeshed/pull/5545
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
1372

1373

1374
class TCPSocketListener(abc.SocketListener):
    """TCP listener that accepts connections through the event loop."""

    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _remove_accept_reader(self) -> None:
        """Drop the reader left registered by an interrupted ``sock_accept()``."""
        # Workaround for https://bugs.python.org/issue41317
        with suppress(ValueError, NotImplementedError):
            self._loop.remove_reader(self._raw_socket)

    async def accept(self) -> abc.SocketStream:
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            # The scope is published so that aclose() can interrupt this wait
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    self._remove_accept_reader()
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            self._remove_accept_reader()
            # Interrupt the pending accept() and give it a chance to unwind
            self._accept_scope.cancel()
            await sleep(0)

        self._raw_socket.close()
1432

1433

1434
class UNIXSocketListener(abc.SocketListener):
    """Listener on a non-blocking UNIX domain socket, driven by loop readers."""

    # Future that a pending accept() is waiting on (if any), so aclose() can wake it
    _accept_future: asyncio.Future | None = None

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """Accept one connection, waiting until the listening socket is readable.

        :raises ClosedResourceError: if the listener has been closed
        :raises BrokenResourceError: on any other socket error
        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, f.set_result, None)
                    f.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    # Publish the future so aclose() can resolve it. Once the
                    # socket is closed, the reader callback is not guaranteed to
                    # fire, so without this the accept() could wait indefinitely.
                    self._accept_future = f
                    try:
                        await f
                    finally:
                        self._accept_future = None
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        self._closed = True
        self.__raw_socket.close()
        # Wake a pending accept(): its retry on the closed socket raises OSError,
        # which it reports as ClosedResourceError
        if self._accept_future is not None and not self._accept_future.done():
            self._accept_future.set_result(None)

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
1469

1470

1471
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return  # already closing

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            buffer_empty = not self._protocol.read_queue
            if buffer_empty and not self._transport.is_closing():
                # Wait for the protocol to deliver a datagram (or lose the
                # connection)
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                # Nothing queued: the transport is presumably closing or closed
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            # Respect flow control before writing
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            # item unpacks into sendto()'s (data, address) arguments
            self._transport.sendto(*item)
1517

1518

1519
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """UDP socket with a fixed remote peer, backed by a datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return  # already closing

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            buffer_empty = not self._protocol.read_queue
            if buffer_empty and not self._transport.is_closing():
                # Wait for the protocol to deliver a datagram (or lose the
                # connection)
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                datagram = self._protocol.read_queue.popleft()
            except IndexError:
                # Nothing queued: the transport is presumably closing or closed
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

            # Queue entries are (data, address); the peer is fixed, so drop it
            return datagram[0]

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            # Respect flow control before writing
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1567

1568

1569
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX domain datagram socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    # Returns a (data, sender address) pair
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # item unpacks into sendto()'s (data, address) arguments
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
1603

1604

1605
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket built on a non-blocking raw socket."""

    async def receive(self) -> bytes:
        """Receive one datagram (at most 65536 bytes) from the connected peer.

        :raises ClosedResourceError: if the socket is being closed
        :raises BrokenResourceError: on any other ``OSError`` from the socket
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Nothing queued yet: park until the socket is readable, then retry
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: bytes) -> None:
        """Send ``item`` as one datagram to the connected peer.

        :raises ClosedResourceError: if the socket is being closed
        :raises BrokenResourceError: on any other ``OSError`` from the socket
        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # Send buffer full: park until the socket is writable, then retry
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
1639

1640

1641
# Registries mapping an object to the asyncio.Event used to wait for readability /
# writability. NOTE(review): RunVar presumably scopes each dict to the current event
# loop run, and the keys are presumably sockets or file descriptors — confirm
# against the wait_readable/wait_writable implementations.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
1643

1644

1645
#
1646
# Synchronization
1647
#
1648

1649

1650
class Event(BaseEvent):
    """asyncio implementation of the AnyIO event, wrapping :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Allocate directly with object.__new__ instead of going through
        # BaseEvent.__new__ (presumably a backend dispatcher — confirm)
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the flag, releasing all current and future waiters."""
        self._event.set()

    def is_set(self) -> bool:
        """Return whether the flag has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Block until the flag is set (checkpointing even if it already is)."""
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics holding the number of tasks currently waiting."""
        waiting = len(self._event._waiters)
        return EventStatistics(waiting)
1671

1672

1673
class Lock(BaseLock):
    """asyncio implementation of the AnyIO lock with FIFO hand-off to waiters.

    ``_owner_task`` is the task currently holding the lock (``None`` when free) and
    ``_waiters`` holds ``(task, future)`` pairs in arrival order. On release,
    ownership is transferred directly to the oldest waiter whose future has not been
    cancelled.
    """

    def __new__(cls) -> Lock:
        # Allocate directly with object.__new__ instead of going through
        # BaseLock.__new__ (presumably a backend dispatcher — confirm)
        return object.__new__(cls)

    def __init__(self) -> None:
        self._owner_task: asyncio.Task | None = None
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """Acquire the lock, waiting in FIFO order if it is held or contended."""
        if self._owner_task is None and not self._waiters:
            # Uncontended: take the lock, but still honour cancellation and yield
            # once so that this remains a checkpoint
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = current_task()
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except CancelledError:
                self.release()
                raise

            return

        task = cast(asyncio.Task, current_task())
        fut: asyncio.Future[None] = asyncio.Future()
        item = task, fut
        self._waiters.append(item)
        try:
            await fut
        except CancelledError:
            self._waiters.remove(item)
            if self._owner_task is task or self._owner_task is None:
                # Either ownership was handed to us just before the cancellation
                # arrived, or the lock was released (skipping our already-cancelled
                # future) while our entry was still queued, so newer arrivals queued
                # up behind it instead of taking the free lock. In both cases pass
                # the lock on, or no waiter would ever be woken.
                self._wake_next_waiter()

            raise

        self._waiters.remove(item)

    def acquire_nowait(self) -> None:
        """Acquire the lock without blocking.

        :raises WouldBlock: if the lock is held or other tasks are queued for it
        """
        if self._owner_task is None and not self._waiters:
            self._owner_task = current_task()
            return

        raise WouldBlock

    def locked(self) -> bool:
        """Return ``True`` if some task currently owns the lock."""
        return self._owner_task is not None

    def release(self) -> None:
        """Release the lock, handing it to the next waiting task if there is one.

        :raises RuntimeError: if the current task does not hold the lock
        """
        if self._owner_task is not current_task():
            raise RuntimeError("The current task is not holding this lock")

        self._wake_next_waiter()

    def _wake_next_waiter(self) -> None:
        """Give the lock to the oldest non-cancelled waiter, or mark it free."""
        for task, fut in self._waiters:
            if not fut.cancelled():
                self._owner_task = task
                fut.set_result(None)
                return

        self._owner_task = None

    def statistics(self) -> LockStatistics:
        """Return the locked state, owner info and number of queued waiters."""
        task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
        return LockStatistics(self.locked(), task_info, len(self._waiters))
1733

1734

1735
class Semaphore(BaseSemaphore):
    """asyncio implementation of the AnyIO semaphore with FIFO hand-off to waiters.

    ``_value`` is the number of free slots. When a slot is released while tasks are
    waiting, it is handed directly to the oldest non-cancelled waiter (whose future is
    then removed from ``_waiters``) instead of incrementing ``_value``.
    """

    def __new__(cls, initial_value: int, *, max_value: int | None = None) -> Semaphore:
        # Allocate directly with object.__new__ instead of going through
        # BaseSemaphore.__new__ (presumably a backend dispatcher — confirm)
        return object.__new__(cls)

    def __init__(self, initial_value: int, *, max_value: int | None = None):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Take one slot, waiting in FIFO order if none is free or others wait."""
        if self._value > 0 and not self._waiters:
            # Uncontended: take a slot, but still honour cancellation and yield once
            # so that this remains a checkpoint
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except CancelledError:
                self.release()
                raise

            return

        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # release() already handed us a slot (and dequeued our future)
                # before the cancellation arrived: pass that slot on
                self.release()
            else:
                # A release() may have skipped our cancelled future and incremented
                # the counter instead, while newer arrivals queued up behind us.
                # Hand such a free slot to the next live waiter, or it is lost.
                if self._value > 0 and self._wake_next_waiter():
                    self._value -= 1

            raise

    def acquire_nowait(self) -> None:
        """Take one slot without blocking.

        :raises WouldBlock: if no slot is free
        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """Return one slot, handing it straight to a waiter if there is one.

        :raises ValueError: if the value would exceed ``max_value``
        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        if not self._wake_next_waiter():
            self._value += 1

    def _wake_next_waiter(self) -> bool:
        """Wake and dequeue the oldest non-cancelled waiter; return whether one was."""
        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                # Safe despite iterating: we stop right after mutating the deque
                self._waiters.remove(fut)
                return True

        return False

    @property
    def value(self) -> int:
        return self._value

    @property
    def max_value(self) -> int | None:
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """Return the number of queued waiters (cancelled ones may still be counted)."""
        return SemaphoreStatistics(len(self._waiters))
1797

1798

1799
class CapacityLimiter(BaseCapacityLimiter):
    """asyncio implementation of the AnyIO capacity limiter.

    Borrowers currently holding a token are kept in ``_borrowers``. Borrowers waiting
    for one are queued (oldest first) in ``_wait_queue``, each with the
    :class:`asyncio.Event` that is set once it may take a token.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Allocate directly with object.__new__ instead of going through
        # BaseCapacityLimiter.__new__ (presumably a backend dispatcher — confirm)
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def _notify_next_waiter(self) -> None:
        """Wake the longest-waiting borrower if the limiter has free capacity now."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def acquire_nowait(self) -> None:
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """Take a token for ``borrower`` without blocking.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if others are queued or no token is free
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Take a token for ``borrower``, waiting in FIFO order if necessary."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A release already dequeued us and set our event, but we were
                    # cancelled before resuming. Forward the wake-up so that the next
                    # waiter is not left blocked while capacity is free.
                    self._notify_next_waiter()

                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def release(self) -> None:
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """Return ``borrower``'s token and wake the next waiter if capacity allows.

        :raises RuntimeError: if ``borrower`` holds no token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1911

1912

1913
# Holder for the default CapacityLimiter used for worker threads.
# NOTE(review): presumably populated lazily by current_default_thread_limiter()
# once per event loop run — confirm.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
1914

1915

1916
#
1917
# Operating system signals
1918
#
1919

1920

1921
class _SignalReceiver:
    """Context manager and async iterator over OS signals received by the loop.

    Entering installs a loop signal handler for each distinct signal; every signal
    received is queued and yielded in order by the async iterator.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        # Called by the loop's signal handler: queue the signal and wake the iterator
        self._signal_queue.append(signum)
        if self._future.done():
            return

        self._future.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        distinct = set(self._signals)
        for signum in distinct:
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Nothing queued: wait on a fresh future for _deliver() to resolve
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
1961

1962

1963
#
1964
# Testing and debugging
1965
#
1966

1967

1968
class AsyncIOTaskInfo(TaskInfo):
    """:class:`TaskInfo` for an asyncio task, keeping only a weak reference to it."""

    def __init__(self, task: asyncio.Task):
        state = _task_states.get(task)
        parent_id = None if state is None else state.parent_id
        super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        """Return whether the task has a native or cancel-scope cancellation pending."""
        task = self._task()
        if not task:
            # The task is gone, so it cannot have a pending cancellation
            return False

        if sys.version_info >= (3, 11):
            natively_cancelling = bool(task.cancelling())
        else:
            waiter = task._fut_waiter
            natively_cancelling = (
                isinstance(waiter, asyncio.Future) and waiter.cancelled()
            )

        if natively_cancelling:
            return True

        state = _task_states.get(task)
        if state:
            scope = state.cancel_scope
            if scope:
                return scope.cancel_called or (
                    not scope.shield and scope._parent_cancelled()
                )

        return False
2000

2001

2002
class TestRunner(abc.TestRunner):
    """Run test functions and fixtures on one :class:`Runner` event loop.

    All coroutines are executed sequentially inside a single long-lived "runner
    task" (see ``_run_tests_and_fixtures``), fed through a memory object stream,
    so fixtures and tests all execute within the same task. Exceptions reported to
    the loop's exception handler are collected in ``_exceptions`` and re-raised
    after each fixture or test.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop_factory takes precedence over use_uvloop
        if use_uvloop and loop_factory is None:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        # Created lazily on the first _call_in_runner_task() call
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        # Collect ordinary exceptions for later re-raising; anything else (no
        # exception, or a BaseException) goes to the loop's default handler
        if isinstance(context.get("exception"), Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        # Re-raise any exceptions raised in asynchronous callbacks
        if self._exceptions:
            exceptions, self._exceptions = self._exceptions, []
            if len(exceptions) == 1:
                raise exceptions[0]
            elif exceptions:
                raise BaseExceptionGroup(
                    "Multiple exceptions occurred in asynchronous callbacks", exceptions
                )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        # Body of the runner task: await each submitted coroutine in turn and report
        # its outcome through the paired future (unless the caller cancelled it)
        with receive_stream, self._send_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)
                else:
                    if not future.cancelled():
                        future.set_result(retval)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        # Start the runner task on first use. A buffer of 1 suffices because every
        # caller (via run_until_complete) waits for its future before submitting more.
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        # Setup: advance the generator to its first yield inside the runner task
        asyncgen = fixture_func(**kwargs)
        fixturevalue: T_Retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(asyncgen.asend, None)
        )
        self._raise_async_exceptions()

        yield fixturevalue

        # Teardown: the generator is expected to finish on the next step
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        # A failing test's exception is raised together with any exceptions from
        # asynchronous callbacks (as a group when there are several)
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
2137

2138

2139
class AsyncIOBackend(AsyncBackend):
11✔
2140
    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """Run ``func(*args)`` to completion on a new event loop and return its result.

        Recognized ``options``:
        - ``debug``: passed to :class:`Runner`.
        - ``loop_factory``.
        - ``use_uvloop``: consulted only when no ``loop_factory`` is given.

        NOTE(review): ``kwargs`` is accepted but is not forwarded to ``func``.
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            # Register the main task in anyio's task bookkeeping for the whole run
            main_task = cast(asyncio.Task, current_task())
            main_task.set_name(get_callable_name(func))
            _task_states[main_task] = TaskState(None, None)
            try:
                return await func(*args)
            finally:
                del _task_states[main_task]

        debug = options.get("debug", None)
        factory = options.get("loop_factory", None)
        if factory is None and options.get("use_uvloop", False):
            import uvloop

            factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=factory) as asyncio_runner:
            return asyncio_runner.run(wrapper())
2168

2169
    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, used as this backend's context token."""
        running_loop = get_running_loop()
        return running_loop
2172

2173
    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's internal clock reading."""
        running_loop = get_running_loop()
        return running_loop.time()
2176

2177
    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception type this backend raises on cancellation."""
        exc_class: type[BaseException] = CancelledError
        return exc_class
2180

2181
    @classmethod
    async def checkpoint(cls) -> None:
        # Yield to the event loop once, giving it a chance to run other tasks and to
        # deliver a pending cancellation to this task
        await sleep(0)
2184

2185
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        # Yield to the event loop only if the current task is inside an effectively
        # cancelled cancel scope; otherwise return without yielding.
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            # Task not tracked by anyio: nothing to check
            return

        # Walk outwards from the innermost scope. A cancelled scope makes us sleep(0)
        # repeatedly (the scope is not advanced) until the pending cancellation is
        # raised into this task; a shielding scope stops the walk.
        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope
2203

2204
    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the event loop once while shielded from anyio cancellation."""
        shield_scope = CancelScope(shield=True)
        with shield_scope:
            await sleep(0)
2208

2209
    @classmethod
    async def sleep(cls, delay: float) -> None:
        # Delegates to the module-level asyncio.sleep imported at the top of the file
        # (this classmethod's own name does not shadow it inside the body)
        await sleep(delay)
2212

2213
    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Return a new :class:`CancelScope` with the given deadline and shielding."""
        scope = CancelScope(shield=shield, deadline=deadline)
        return scope
2218

2219
    @classmethod
    def current_effective_deadline(cls) -> float:
        """Return the nearest deadline over the task's enclosing cancel scopes.

        The walk stops at the first shielding scope. It returns ``-math.inf`` if a
        scope on the walk has already been cancelled, and ``math.inf`` for tasks
        that anyio does not track.
        """
        try:
            scope = _task_states[current_task()].cancel_scope  # type: ignore[index]
        except KeyError:
            return math.inf

        effective = math.inf
        while scope:
            effective = min(effective, scope.deadline)
            if scope._cancel_called:
                return -math.inf

            if scope.shield:
                break

            scope = scope._parent_scope

        return effective
2240

2241
    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        # Factory for this backend's TaskGroup implementation
        return TaskGroup()
2244

2245
    @classmethod
    def create_event(cls) -> abc.Event:
        # Factory for this backend's Event implementation (defined in this module)
        return Event()
2248

2249
    @classmethod
    def create_lock(cls) -> abc.Lock:
        # Factory for this backend's Lock implementation (defined in this module)
        return Lock()
2252

2253
    @classmethod
    def create_semaphore(
        cls, initial_value: int, *, max_value: int | None = None
    ) -> abc.Semaphore:
        # Factory for this backend's Semaphore implementation (defined in this module)
        return Semaphore(initial_value, max_value=max_value)
2258

2259
    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        # Factory for this backend's CapacityLimiter (defined in this module)
        return CapacityLimiter(total_tokens)
2262

2263
    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        # Run func(*args) in a pooled WorkerThread and await its result. Concurrency
        # is bounded by ``limiter`` (or the default thread limiter). Unless
        # ``abandon_on_cancel`` is true, the wait is shielded from cancellation.
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    # No idle worker: start a new one and stop it when the root
                    # task finishes
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    # Reuse the worker at the right end of the deque (presumably the
                    # most recently idled one — confirm in WorkerThread)
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # Run func in a copy of the caller's context, with sniffio reporting
                # no async library inside the thread
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                # When not abandoning, hand the worker the caller's enclosing scope
                # rather than our shielded one (presumably so check_cancelled() in
                # the thread observes the caller's cancellation — confirm)
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future
2321

2322
    @classmethod
    def check_cancelled(cls) -> None:
        """Raise ``CancelledError`` if the thread's cancel scope chain is cancelled.

        Walks outwards from ``threadlocals.current_cancel_scope``, stopping at the
        first shielding scope.
        """
        current: CancelScope | None = threadlocals.current_cancel_scope
        while current is not None:
            if current.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(current):x}")

            if current.shield:
                return

            current = current._parent_scope
2333

2334
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        # Called from a non-event-loop thread: schedule func(*args) on the loop given
        # by ``token`` and block until it finishes, returning its result.

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            # Hides this frame from pytest tracebacks
            __tracebackhide__ = True
            # Register the new task under ``scope`` so that cancelling that scope
            # also reaches it
            task = cast(asyncio.Task, current_task())
            _task_states[task] = TaskState(None, scope)
            scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                # Re-raise as the concurrent.futures flavour for the waiting thread
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(task)

        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        # NOTE(review): threadlocals.current_cancel_scope is presumably the scope
        # the calling worker thread was started with — confirm
        wrapper = task_wrapper(threadlocals.current_cancel_scope)
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, wrapper, loop
        )
        return f.result()
2361

2362
    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """Run ``func(*args)`` in the event loop thread and block until it returns.

        The call is scheduled with ``call_soon_threadsafe`` on the loop given by
        ``token``. Its result or exception is relayed through a
        :class:`concurrent.futures.Future`. Non-``Exception`` errors are also
        re-raised inside the loop callback.
        """
        outcome: concurrent.futures.Future[T_Retval] = Future()

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                outcome.set_result(func(*args))
            except BaseException as exc:
                outcome.set_exception(exc)
                if not isinstance(exc, Exception):
                    raise

        cast(AbstractEventLoop, token).call_soon_threadsafe(wrapper)
        return outcome.result()
2383

2384
    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        # Factory for this backend's BlockingPortal implementation
        return BlockingPortal()
2387

2388
    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """Start a subprocess and wrap it, with its pipes, in a :class:`Process`.

        With ``shell`` set, ``command`` is run through the shell as a single string.
        Otherwise it is unpacked as the argument vector.
        """
        await cls.checkpoint()
        spawn_options: dict[str, Any] = dict(
            stdin=stdin,
            stdout=stdout,
            stderr=stderr,
            cwd=cwd,
            env=env,
            start_new_session=start_new_session,
        )
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command), **spawn_options
            )
        else:
            process = await asyncio.create_subprocess_exec(*command, **spawn_options)

        # Only pipes that were actually created get a stream wrapper
        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)
2427

2428
    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """Arrange for the worker process pool to be shut down at exit.

        Two mechanisms are set up:

        * A background task runs ``_shutdown_process_pool_on_exit(workers)``.
        * When the root task finishes, a done-callback runs
          ``_forcibly_shutdown_process_pool_on_exit`` with *workers*.
        """
        shutdown_coro = _shutdown_process_pool_on_exit(workers)
        create_task(shutdown_coro, name="AnyIO process pool shutdown task")

        root_task = find_root_task()
        root_task.add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)  # type:ignore[arg-type]
        )
2437

2438
    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """Open a TCP connection to ``host:port`` and return it as a socket stream.

        When ``local_address`` is given, it is passed to the event loop as the
        local bind address.  Reading on the new transport is paused
        immediately.  NOTE(review): presumably :class:`SocketStream` resumes
        reading on demand; confirm against that class.
        """
        loop = get_running_loop()
        connection = await loop.create_connection(
            StreamProtocol, host, port, local_addr=local_address
        )
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol], connection
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)
2450

2451
    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """Connect to the UNIX domain socket at *path* and wrap it in a stream.

        The socket is non-blocking.  When ``connect()`` raises
        :exc:`BlockingIOError`, this waits until the socket becomes writable
        and then retries.  If anything fails before the stream is returned,
        the raw socket is closed.  That includes cancellation while waiting
        for writability.
        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        # The cleanup handler must wrap the whole loop.  A sibling
        # ``except BaseException`` clause would NOT catch an exception raised
        # inside the ``except BlockingIOError`` handler (e.g. CancelledError
        # from ``await f``), which would leak the socket.
        try:
            raw_socket.setblocking(False)
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    f: asyncio.Future = loop.create_future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    # Runs on completion or cancellation, so the writer
                    # registration never outlives the wait.
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
2470

2471
    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap an already-bound TCP socket in an asyncio TCP listener."""
        listener = TCPSocketListener(sock)
        return listener
2474

2475
    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap an already-bound UNIX socket in an asyncio UNIX listener."""
        listener = UNIXSocketListener(sock)
        return listener
2478

2479
    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """Create a UDP datagram endpoint and wrap it in an AnyIO socket.

        If the protocol recorded an exception during setup, the transport is
        closed and that exception is raised.  A truthy ``remote_address``
        produces a :class:`ConnectedUDPSocket`.  Otherwise an unconnected
        :class:`UDPSocket` is returned.
        """
        loop = get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        wrapper_class = ConnectedUDPSocket if remote_address else UDPSocket
        return wrapper_class(transport, protocol)
2502

2503
    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """Wrap *raw_socket* as a UNIX datagram socket.

        With a falsy ``remote_path`` the socket is wrapped unconnected.
        Otherwise it is connected to ``remote_path`` first.  When
        ``connect()`` raises :exc:`BlockingIOError`, this waits for
        writability and retries.  If connecting fails for any reason, the raw
        socket is closed.  That includes cancellation during the wait.
        """
        await cls.checkpoint()
        loop = get_running_loop()

        if not remote_path:
            return UNIXDatagramSocket(raw_socket)

        # Cleanup wraps the whole loop.  A sibling ``except BaseException``
        # clause would not catch errors raised inside the BlockingIOError
        # handler (e.g. CancelledError from ``await f``).
        try:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    f: asyncio.Future = loop.create_future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    # Runs on completion or cancellation alike.
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
2526

2527
    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host*/*port* via the running loop's ``getaddrinfo``.

        All filter arguments are forwarded unchanged.
        """
        loop = get_running_loop()
        resolved = await loop.getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )
        return resolved
2549

2550
    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Reverse-resolve *sockaddr* via the running loop's ``getnameinfo``."""
        loop = get_running_loop()
        host_and_service = await loop.getnameinfo(sockaddr, flags)
        return host_and_service
2555

2556
    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """Wait until the event loop reports *sock* as readable.

        Waiters are tracked per context in the ``_read_events`` mapping, which
        is keyed by the socket object.

        Raises:
            BusyResourceError: another task is already waiting to read from
                *sock*.
            ClosedResourceError: when the wait ended, this socket's entry had
                already been removed from the mapping by someone else.
                NOTE(review): presumably done by a socket-closing notification
                elsewhere in this module; confirm.
        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = asyncio.Event()
        # Install the reader before recording the wait.  If add_reader()
        # raises (unsupported loop, invalid fd), no stale entry is left
        # behind to make every later call fail with BusyResourceError.
        loop.add_reader(sock, event.set)
        read_events[sock] = event
        try:
            await event.wait()
        finally:
            # If our entry is still present, we own the reader registration.
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError
2582

2583
    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """Wait until the event loop reports *sock* as writable.

        Waiters are tracked per context in the ``_write_events`` mapping, which
        is keyed by the socket object.

        Raises:
            BusyResourceError: another task is already waiting to write to
                *sock*.
            ClosedResourceError: when the wait ended, this socket's entry had
                already been removed from the mapping by someone else.
                NOTE(review): presumably done by a socket-closing notification
                elsewhere in this module; confirm.
        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = asyncio.Event()
        # Register and unregister with the same object (the socket itself),
        # as the readable variant does.  Record the wait only after
        # add_writer() succeeds, so a failure leaves no stale entry.
        loop.add_writer(sock, event.set)
        write_events[sock] = event
        try:
            await event.wait()
        finally:
            # If our entry is still present, we own the writer registration.
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError
2609

2610
    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return this context's default worker-thread limiter.

        On first use in a context, a limiter with 40 tokens is created and
        stored in ``_default_thread_limiter``.
        """
        try:
            existing = _default_thread_limiter.get()
        except LookupError:
            existing = CapacityLimiter(40)
            _default_thread_limiter.set(existing)
        return existing
2618

2619
    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """Return a context manager that yields the given *signals* as received."""
        receiver = _SignalReceiver(signals)
        return receiver
2624

2625
    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Describe the currently running asyncio task as a ``TaskInfo``."""
        running = current_task()
        return AsyncIOTaskInfo(running)  # type: ignore[arg-type]
2628

2629
    @classmethod
    def get_running_tasks(cls) -> Sequence[TaskInfo]:
        """Return ``TaskInfo`` objects for every task that is not yet done."""
        unfinished = (task for task in all_tasks() if not task.done())
        return [AsyncIOTaskInfo(task) for task in unfinished]
2632

2633
    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Return once every other task is blocked on a pending future.

        This polls :func:`asyncio.all_tasks` every 0.1 seconds.  The calling
        task is ignored.  Any other task counts as still runnable when its
        private ``_fut_waiter`` attribute is ``None`` or refers to a future
        that is already done.
        """
        await cls.checkpoint()
        me = current_task()

        def is_runnable(task: asyncio.Task) -> bool:
            if task is me:
                return False
            waiter = task._fut_waiter  # type: ignore[attr-defined]
            return waiter is None or waiter.done()

        # any() stops at the first runnable task, like the original break.
        while any(is_runnable(task) for task in all_tasks()):
            await sleep(0.1)
2648

2649
    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Build a ``TestRunner``, passing *options* through as keyword arguments."""
        runner = TestRunner(**options)
        return runner
2652

2653

2654
# Module-level alias exposing this module's backend implementation class.
# NOTE(review): presumably looked up by name when AnyIO loads a backend
# module; confirm against the backend loader.
backend_class = AsyncIOBackend
11✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc