• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 10716642356

05 Sep 2024 08:22AM UTC coverage: 91.7% (+0.005%) from 91.695%
10716642356

Pull #782

github

web-flow
Merge 8854bb22b into 0c8ad519e
Pull Request #782: Accept abstract namespace paths for unix domain sockets

9 of 10 new or added lines in 1 file covered. (90.0%)

5 existing lines in 1 file now uncovered.

4795 of 5229 relevant lines covered (91.7%)

9.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.7
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
11✔
2

3
import array
11✔
4
import asyncio
11✔
5
import concurrent.futures
11✔
6
import math
11✔
7
import os
11✔
8
import socket
11✔
9
import sys
11✔
10
import threading
11✔
11
import weakref
11✔
12
from asyncio import (
11✔
13
    AbstractEventLoop,
14
    CancelledError,
15
    all_tasks,
16
    create_task,
17
    current_task,
18
    get_running_loop,
19
    sleep,
20
)
21
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
11✔
22
from collections import OrderedDict, deque
11✔
23
from collections.abc import AsyncIterator, Iterable
11✔
24
from concurrent.futures import Future
11✔
25
from contextlib import suppress
11✔
26
from contextvars import Context, copy_context
11✔
27
from dataclasses import dataclass
11✔
28
from functools import partial, wraps
11✔
29
from inspect import (
11✔
30
    CORO_RUNNING,
31
    CORO_SUSPENDED,
32
    getcoroutinestate,
33
    iscoroutine,
34
)
35
from io import IOBase
11✔
36
from os import PathLike
11✔
37
from queue import Queue
11✔
38
from signal import Signals
11✔
39
from socket import AddressFamily, SocketKind
11✔
40
from threading import Thread
11✔
41
from types import TracebackType
11✔
42
from typing import (
11✔
43
    IO,
44
    Any,
45
    AsyncGenerator,
46
    Awaitable,
47
    Callable,
48
    Collection,
49
    ContextManager,
50
    Coroutine,
51
    Optional,
52
    Sequence,
53
    Tuple,
54
    TypeVar,
55
    cast,
56
)
57
from weakref import WeakKeyDictionary
11✔
58

59
import sniffio
11✔
60

61
from .. import (
11✔
62
    CapacityLimiterStatistics,
63
    EventStatistics,
64
    LockStatistics,
65
    TaskInfo,
66
    abc,
67
)
68
from .._core._eventloop import claim_worker_thread, threadlocals
11✔
69
from .._core._exceptions import (
11✔
70
    BrokenResourceError,
71
    BusyResourceError,
72
    ClosedResourceError,
73
    EndOfStream,
74
    WouldBlock,
75
    iterate_exceptions,
76
)
77
from .._core._sockets import convert_ipv6_sockaddr
11✔
78
from .._core._streams import create_memory_object_stream
11✔
79
from .._core._synchronization import (
11✔
80
    CapacityLimiter as BaseCapacityLimiter,
81
)
82
from .._core._synchronization import Event as BaseEvent
11✔
83
from .._core._synchronization import Lock as BaseLock
11✔
84
from .._core._synchronization import (
11✔
85
    ResourceGuard,
86
    SemaphoreStatistics,
87
)
88
from .._core._synchronization import Semaphore as BaseSemaphore
11✔
89
from .._core._tasks import CancelScope as BaseCancelScope
11✔
90
from ..abc import (
11✔
91
    AsyncBackend,
92
    IPSockAddrType,
93
    SocketListener,
94
    UDPPacketType,
95
    UNIXDatagramPacketType,
96
)
97
from ..abc._eventloop import StrOrBytesPath
11✔
98
from ..lowlevel import RunVar
11✔
99
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11✔
100

101
if sys.version_info >= (3, 10):
11✔
102
    from typing import ParamSpec
7✔
103
else:
104
    from typing_extensions import ParamSpec
4✔
105

106
if sys.version_info >= (3, 11):
11✔
107
    from asyncio import Runner
5✔
108
    from typing import TypeVarTuple, Unpack
5✔
109
else:
110
    import contextvars
6✔
111
    import enum
6✔
112
    import signal
6✔
113
    from asyncio import coroutines, events, exceptions, tasks
6✔
114

115
    from exceptiongroup import BaseExceptionGroup
6✔
116
    from typing_extensions import TypeVarTuple, Unpack
6✔
117

118
    class _State(enum.Enum):
        """Lifecycle states of the backported :class:`Runner` (Python < 3.11)."""

        CREATED = "created"
        INITIALIZED = "initialized"
        CLOSED = "closed"
122

123
    class Runner:
        """
        Backport of :class:`asyncio.Runner` for Python versions older than 3.11.

        Copied from CPython 3.11. The event loop and the :mod:`contextvars` context
        used for running coroutines are created lazily on first use, and torn down
        by :meth:`close` (which is also called when leaving a ``with`` block).
        """

        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            self._interrupt_count = 0
            # Whether this runner installed its loop as the current event loop, so
            # that close() knows to uninstall it again
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """
            Shutdown and close event loop.

            Cancels all remaining tasks, shuts down async generators and the default
            executor, and then closes the loop. Does nothing unless the runner is
            currently initialized.
            """
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    # Loops without shutdown_default_executor() use the local shim
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(
            self,
            coro: Coroutine[Any, Any, T_Retval],
            *,
            context: Context | None = None,
        ) -> T_Retval:
            """
            Run a coroutine inside the embedded event loop.

            :param coro: the coroutine to run as the main task
            :param context: the context to create the task in (defaults to the
                runner's own context)
            :return: the coroutine's return value
            :raises ValueError: if ``coro`` is not a coroutine
            :raises RuntimeError: if called while an event loop is already running
                in this thread, or after the runner has been closed

            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            task = context.run(self._loop.create_task, coro)

            # Only install our SIGINT handler if nobody has replaced the default one
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                # Turn a cancellation caused by Ctrl+C back into KeyboardInterrupt,
                # if the task supports uncancel() and no other cancellation remains
                if self._interrupt_count > 0:
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """
            Create the event loop and context on first use.

            :raises RuntimeError: if the runner has already been closed

            """
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(
            self, signum: int, frame: object, main_task: asyncio.Task
        ) -> None:
            """
            SIGINT handler installed by :meth:`run`.

            The first interrupt cancels the main task (if it has not finished yet)
            and wakes up the loop; any further interrupt raises
            :exc:`KeyboardInterrupt` directly.
            """
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
250

251
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """
        Cancel every task of ``loop`` and run the loop until all of them finish.

        Tasks that ended with an exception other than cancellation are reported to
        the loop's exception handler.
        """
        to_cancel = tasks.all_tasks(loop)
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()

        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))

        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler(
                    {
                        "message": "unhandled exception during asyncio.run() shutdown",
                        "exception": task.exception(),
                        "task": task,
                    }
                )
272

273
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """
        Schedule the shutdown of the default executor.

        Fallback for event loops that lack ``shutdown_default_executor()``. The
        executor is shut down with ``wait=True`` in a separate thread, and this
        coroutine awaits a future that the thread resolves, so the event loop keeps
        running while the executor drains.
        """

        def _do_shutdown(future: asyncio.futures.Future) -> None:
            # Runs in the helper thread; results go back via call_soon_threadsafe
            try:
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
                loop.call_soon_threadsafe(future.set_result, None)
            except Exception as ex:
                loop.call_soon_threadsafe(future.set_exception, ex)

        loop._executor_shutdown_called = True
        if loop._default_executor is None:
            return
        future = loop.create_future()
        thread = threading.Thread(target=_do_shutdown, args=(future,))
        thread.start()
        try:
            await future
        finally:
            thread.join()
293

294

295
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

# Root task cached by find_root_task()
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
301

302

303
def find_root_task() -> asyncio.Task:
    """
    Return the task regarded as the root task of the current event loop run.

    The lookup order is:

    #. the task previously cached in ``_root_task``, if it has not finished
    #. an unfinished task carrying a ``run_until_complete()`` completion callback
       (or any callback from the ``uvloop.loop`` module); this task is cached
    #. the host task of the outermost cancel scope reachable from the current
       task's current cancel scope
    #. the current task itself

    """
    root_task = _root_task.get(None)
    if root_task is not None and not root_task.done():
        return root_task

    # Look for a task that has been started via run_until_complete()
    for task in all_tasks():
        if task._callbacks and not task.done():
            for cb, _context in task._callbacks:
                if (
                    cb is _run_until_complete_cb
                    or getattr(cb, "__module__", None) == "uvloop.loop"
                ):
                    _root_task.set(task)
                    return task

    # Look up the topmost task in the AnyIO task tree, if possible
    task = cast(asyncio.Task, current_task())
    state = _task_states.get(task)
    if state is not None:
        cancel_scope = state.cancel_scope
        while cancel_scope and cancel_scope._parent_scope is not None:
            cancel_scope = cancel_scope._parent_scope

        if cancel_scope is not None:
            return cast(asyncio.Task, cancel_scope._host_task)

    return task
332

333

334
def get_callable_name(func: Callable) -> str:
    """
    Return a dotted ``module.qualname`` name for ``func``.

    Missing or empty ``__module__`` / ``__qualname__`` parts are left out, so the
    result may be just one of them, or an empty string.
    """
    module = getattr(func, "__module__", None)
    qualname = getattr(func, "__qualname__", None)
    return ".".join([x for x in (module, qualname) if x])
338

339

340
#
341
# Event loop
342
#
343

344
# Storage keyed weakly by event loop — presumably backs RunVar values; TODO confirm
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
345

346

347
def _task_started(task: asyncio.Task) -> bool:
    """
    Return ``True`` if the task has been started and has not finished.

    :raises Exception: if the task's coroutine state cannot be inspected

    """
    try:
        return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
    except AttributeError:
        # task coro is async_generator_asend https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not") from None
354

355

356
#
357
# Timeouts and cancellation
358
#
359

360

361
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of an AnyIO cancel scope.

    Cancellation is delivered by calling :meth:`asyncio.Task.cancel` on the tasks
    whose innermost scope is this one (and, recursively, on those of unshielded
    child scopes). Delivery reschedules itself with ``call_soon()`` for as long as
    eligible tasks remain.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Create the instance directly with object.__new__ so that the base
        # class's __new__ is not invoked for this concrete backend class
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Loop callback that re-checks the deadline (see _timeout())
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Loop callback that retries cancellation delivery, or None when idle
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks for which this is the innermost cancel scope
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls issued on behalf of this scope
        self._cancel_calls: int = 0
        # Host task's cancelling() count at entry (only set on Python 3.11+)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        """
        Enter the scope in the current task, nesting it under the task's current
        cancel scope (if any).

        :raises RuntimeError: if the scope is already active

        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First cancel scope seen in this task: register the task
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest under the current scope; the host task now belongs to this
            # scope rather than to the parent
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.remove(host_task)

        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Exit the scope and decide whether to swallow the propagating exception.

        :return: ``True`` if the scope was cancelled and one of the exceptions
            enumerated by ``iterate_exceptions(exc_val)`` is a
            :exc:`~asyncio.CancelledError` attributed to this scope; a false value
            if the scope was cancelled but no such error was found; ``None`` if the
            scope was not cancelled or no exception is propagating
        :raises RuntimeError: if the scope is not active, is exited in a different
            task than it was entered in, or is not the task's current scope

        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Hand the host task back to the parent scope
        self._tasks.remove(self._host_task)
        if self._parent_scope is not None:
            self._parent_scope._child_scopes.remove(self)
            self._parent_scope._tasks.add(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the closest directly cancelled parent
        # scope if this one was shielded
        self._restart_cancellation_in_parent()

        if self._cancel_called and exc_val is not None:
            for exc in iterate_exceptions(exc_val):
                if isinstance(exc, CancelledError):
                    self._cancelled_caught = self._uncancel(exc)
                    if self._cancelled_caught:
                        break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Undo the cancellation requests this scope made on its host task.

        :param cancelled_exc: a cancellation error seen while exiting the scope
        :return: ``True`` if the error should be treated as caused by this scope

        """
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        # Otherwise fall back to matching the cancel message set by this scope
        self._cancel_calls = 0
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        """
        Cancel the scope if its deadline has passed; otherwise schedule this method
        to run again at the deadline.
        """
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks that already have a cancel request pending
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            should_retry = True
            if task is not current and (task is self._host_task or _task_started(task)):
                # Don't cancel a task whose awaited future has already completed
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    origin._cancel_calls += 1
                    if sys.version_info >= (3, 9):
                        task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    else:
                        task.cancel()

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if no delivery callback is currently scheduled
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if any parent up to the nearest shielded one is cancelled."""
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """
        Cancel this scope.

        Idempotent. Delivery starts immediately if the scope has been entered;
        otherwise it starts when the scope is entered.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the deadline check for an active, not yet cancelled scope
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Lifting the shield exposes this scope to pending parent cancellation
            if not value:
                self._restart_cancellation_in_parent()
602

603

604
#
605
# Task states
606
#
607

608

609
class TaskState:
    """
    Encapsulates auxiliary task information that cannot be added to the Task instance
    itself because there are no guarantees about its implementation.
    """

    # __weakref__ allows instances to be weakly referenced despite __slots__
    __slots__ = "parent_id", "cancel_scope", "__weakref__"

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        # id() of the parent task, or None if unknown
        self.parent_id = parent_id
        # The task's current (innermost) cancel scope
        self.cancel_scope = cancel_scope
620

621

622
# Auxiliary per-task state; weak keys let entries disappear along with their tasks
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
623

624

625
#
626
# Task groups
627
#
628

629

630
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object passed to child tasks launched via ``TaskGroup.start()``.

    :param future: resolved with the value given to :meth:`started`
    :param parent_id: ``id()`` of the task recorded as the calling task's parent
        once :meth:`started` has been called

    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._future = future
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """
        Signal that the task has started, handing ``value`` to the waiting caller.

        :raises RuntimeError: if called more than once (unless the future was
            cancelled in the meantime)

        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            # A cancelled future is tolerated; a completed one means a double call
            if not self._future.cancelled():
                raise RuntimeError(
                    "called 'started' twice on the same task status"
                ) from None

        # Re-parent the calling task
        task = cast(asyncio.Task, current_task())
        _task_states[task].parent_id = self._parent_id
646

647

648
class TaskGroup(abc.TaskGroup):
    """asyncio implementation of an AnyIO task group."""

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        # Exceptions collected from the body and from child tasks
        self._exceptions: list[BaseException] = []
        # Child tasks that have not finished yet
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Exit the group, waiting for all child tasks to finish.

        If the body raised, the group's cancel scope is cancelled. Non-cancellation
        exceptions from the body and the child tasks are raised together as a
        :exc:`BaseExceptionGroup`.
        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        while self._tasks:
            try:
                await asyncio.wait(self._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` within this task group.

        :param func: coroutine function to run
        :param args: positional arguments for ``func``
        :param name: task name; derived from ``func`` when ``None``
        :param task_status_future: when given (by :meth:`start`), ``func`` also
            receives a ``task_status`` keyword argument that resolves this future
        :return: the created task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` does not return a coroutine object

        """

        def task_done(_task: asyncio.Task) -> None:
            # Done callback: detach the finished task from its cancel scope and from
            # this group, then route its outcome
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Walk the __context__ chain back to the original CancelledError
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._parent_cancelled():
                        self.cancel_scope.cancel()
                else:
                    # Failed before calling task_status.started(); start() will
                    # raise this exception from its awaited future
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs: dict[str, Any] = {}
        if task_status_future:
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        """Start ``func(*args)`` as a child task without waiting for it."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start a child task and wait until it calls ``task_status.started()``.

        :return: the value passed to ``task_status.started()``
        :raises RuntimeError: if the child exits without calling
            ``task_status.started()``

        Any exception the child raises before calling ``started()`` propagates
        from here.
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
803

804

805
#
806
# Threads
807
#
808

809
# (return value, exception) pair describing the outcome of a job
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
810

811

812
class WorkerThread(Thread):
    """
    Thread that runs synchronous callables on behalf of an event loop.

    Jobs arrive through :attr:`queue` as ``(context, func, args, future,
    cancel_scope)`` tuples; a ``None`` item makes the thread exit. Each outcome is
    handed back to the event loop of ``root_task`` via ``call_soon_threadsafe()``.

    :param root_task: task whose event loop receives the results
    :param workers: set of live workers; :meth:`stop` removes this thread from it
    :param idle_workers: idle workers; this thread appends itself after each job
        (unless stopping) and :meth:`stop` removes it again
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Presumably sized for one pending job plus the ``None`` shutdown sentinel
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Deliver a job's outcome to ``future`` and mark this worker idle again.

        Scheduled onto the event loop by :meth:`run`. Nothing is delivered if the
        future was cancelled in the meantime. A ``StopIteration`` is wrapped in a
        ``RuntimeError`` (chained as its cause) before being set on the future.
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                if isinstance(exc, StopIteration):
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Process jobs from :attr:`queue` until the ``None`` sentinel arrives."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future, cancel_scope = item
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                    # Drop this frame's references to the outcome. A caught
                    # exception's traceback references this very frame, so keeping it
                    # in a local creates a reference cycle; and a (possibly large)
                    # return value would otherwise stay alive while the thread sits
                    # idle in queue.get() waiting for the next job.
                    del result, exception

                # Likewise release the finished job (callable, arguments, context,
                # future) before blocking on the queue again
                del item, context, func, args, future, cancel_scope
                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Ask the thread to exit and deregister it from the worker collections.

        ``f`` is ignored; presumably present so this can be used as a task done
        callback. TODO confirm against the caller.
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # Not currently idle: nothing to remove
            pass
885

886

887
# RunVar-held registries for the worker thread pool: the deque of idle worker
# threads, and the set of all live worker threads
_threadpool_idle_workers: RunVar[deque[WorkerThread]]
_threadpool_idle_workers = RunVar("_threadpool_idle_workers")
_threadpool_workers: RunVar[set[WorkerThread]]
_threadpool_workers = RunVar("_threadpool_workers")
891

892

893
class BlockingPortal(abc.BlockingPortal):
    """
    asyncio implementation of the blocking portal, bound to the event loop that is
    running when it is created.
    """

    def __new__(cls) -> BlockingPortal:
        # Plain object allocation; bypasses whatever the base class __new__ does
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        # From the calling (non-event-loop) thread, ask the portal's loop to start
        # ``self._call_func(func, args, kwargs, future)`` in the portal's task group
        start_in_group = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_in_group, call_args, self._loop)
914

915

916
#
917
# Subprocesses
918
#
919

920

921
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Exposes an :class:`asyncio.StreamReader` as an AnyIO byte receive stream."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            # StreamReader.read() returns b"" once EOF has been reached
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        self._stream.feed_eof()
        await AsyncIOBackend.checkpoint()
935

936

937
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Exposes an :class:`asyncio.StreamWriter` as an AnyIO byte send stream."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        writer = self._stream
        writer.write(item)
        # Apply backpressure: wait until the write buffer has drained enough
        await writer.drain()

    async def aclose(self) -> None:
        self._stream.close()
        await AsyncIOBackend.checkpoint()
948

949

950
@dataclass(eq=False)
class Process(abc.Process):
    """
    :class:`abc.Process` wrapping an asyncio subprocess together with AnyIO wrappers
    for whichever of its standard streams were piped (each may be ``None``).
    """

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        # Close the existing standard streams; shielded so cancellation cannot cut
        # this step short
        with CancelScope(shield=True):
            for stream in (self._stdin, self._stdout, self._stderr):
                if stream:
                    await stream.aclose()

        try:
            await self.wait()
        except BaseException:
            # Interrupted while waiting (e.g. cancelled): kill the process and reap
            # it under a shield, then propagate the original exception
            self.kill()
            with CancelScope(shield=True):
                await self.wait()

            raise

    async def wait(self) -> int:
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
1006

1007

1008
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down the worker processes of this event loop that still run.

    Everything here is synchronous (no ``await``). For each worker that has not
    exited yet, its pipe transports are closed, the process is killed and, on
    Python < 3.12, its child watcher handler is removed. Workers that have already
    exited are left alone.

    :param workers: the worker processes to check
    :param _task: ignored; presumably present so this can be used as a task done
        callback. TODO confirm against the caller.
    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # Only processes that are still running need a forcible shutdown. (This
        # guard was previously inverted, so running workers were never killed.)
        if process.returncode is not None:
            continue

        # The stream attributes are Optional, so skip any that are missing instead
        # of failing with AttributeError
        for stream in (process._stdin, process._stdout, process._stderr):
            if stream is not None:
                stream._stream._transport.close()  # type: ignore[union-attr]

        # The process may have disappeared in the meantime; shutdown is best-effort
        with suppress(ProcessLookupError):
            process.kill()

        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
1031

1032

1033
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps until cancelled; on cancellation every still-running worker is killed and
    then all workers are closed. The cancellation is absorbed rather than re-raised.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        still_running = [proc for proc in workers if proc.returncode is None]
        for proc in still_running:
            proc.kill()

        for proc in workers:
            await proc.aclose()
1051

1052

1053
#
1054
# Sockets and networking
1055
#
1056

1057

1058
class StreamProtocol(asyncio.Protocol):
    """
    Stream protocol that buffers received chunks in a deque and signals readers and
    writers through asyncio events.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()  # writing is allowed until the transport asks us to pause
        self.write_event = writable
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up anyone waiting to read or write
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        # ProactorEventloop sometimes sends bytearray instead of bytes
        chunk = bytes(data)
        self.read_queue.append(chunk)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        # Returning True keeps the transport open (half-closed) after EOF
        return True

    def pause_writing(self) -> None:
        # Swap in a fresh, unset event instead of clearing the current one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
1095

1096

1097
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Datagram protocol that buffers received ``(data, address)`` packets for the UDP
    socket classes and tracks write readiness with an event.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded buffer (arbitrary size): once full, appending a new packet
        # discards the oldest one
        self.read_queue = deque(maxlen=100)
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up anyone waiting to read or write
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        self.read_queue.append((data, convert_ipv6_sockaddr(addr)))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1126

1127

1128
class SocketStream(abc.SocketStream):
    """Socket stream built on an asyncio transport and a :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        protocol = self._protocol
        with self._receive_guard:
            must_wait = not (
                protocol.read_event.is_set()
                or self._transport.is_closing()
                or protocol.is_at_eof
            )
            if must_wait:
                # Let the transport read only for the duration of the wait
                self._transport.resume_reading()
                await protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                if protocol.exception:
                    raise protocol.exception from None
                raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Return max_bytes now and put the remainder back at the front
                protocol.read_queue.appendleft(chunk[max_bytes:])
                chunk = chunk[:max_bytes]

            # Once the queue is drained, the next call must block until data arrives
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise
                raise BrokenResourceError from exc

            # Read the event only after write(): pause_writing() may have replaced it
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        with suppress(OSError):
            self._transport.write_eof()

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        with suppress(OSError):
            self._transport.write_eof()

        self._transport.close()
        await sleep(0)
        self._transport.abort()
1211

1212

1213
class _RawSocketMixin:
    """
    Shared plumbing for stream/datagram classes that drive a raw socket directly,
    waiting for readiness through the event loop's reader/writer callbacks.
    """

    # Readiness futures of a pending receive/send; instance attributes shadow these
    # while a wait is in progress and are deleted again by the done callbacks
    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future resolved once the socket becomes readable."""

        def callback(f: object) -> None:
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future resolved once the socket becomes writable."""

        def callback(f: object) -> None:
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        """Close the socket (once) and wake up any pending receive/send wait."""
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # A readiness future may already be done while its cleanup callback has
            # not run yet, e.g. it was cancelled together with its waiting task or
            # the socket just became ready. set_result() would then raise
            # InvalidStateError, so only resolve futures that are still pending.
            for future in (self._receive_future, self._send_future):
                if future is not None and not future.done():
                    future.set_result(None)
1257

1258

1259
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """
    UNIX domain stream socket operating on the raw socket, waiting for readiness via
    the event loop whenever an operation would block. Also supports passing file
    descriptors (SCM_RIGHTS).
    """

    async def send_eof(self) -> None:
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if not data:
                    # recv() returns b"" once the peer has shut down its side
                    raise EndOfStream

                return data

    async def send(self, item: bytes) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    # send() may be partial; keep going with what's left
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive up to ``msglen`` bytes plus up to ``maxfds`` file descriptors.

        :raises ValueError: if ``msglen`` or ``maxfds`` is invalid
        :raises EndOfStream: if neither data nor ancillary data was received
        :raises RuntimeError: if ancillary data other than SCM_RIGHTS arrives
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream

                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore a trailing partial item (a possibly truncated descriptor)
            whole_items = len(cmsg_data) - len(cmsg_data) % fds.itemsize
            fds.frombytes(cmsg_data[:whole_items])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send ``message`` along with the given file descriptors.

        Integers are used as descriptors directly and ``IOBase`` objects via their
        ``fileno()``; any other items in ``fds`` are skipped.

        :raises ValueError: if ``message`` or ``fds`` is empty
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    break
1373

1374

1375
class TCPSocketListener(abc.SocketListener):
    """TCP listener accepting connections through ``loop.sock_accept()``."""

    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _drop_accept_reader(self) -> None:
        # Workaround for https://bugs.python.org/issue41317
        try:
            self._loop.remove_reader(self._raw_socket)
        except (ValueError, NotImplementedError):
            pass

    async def accept(self) -> abc.SocketStream:
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    self._drop_accept_reader()
                    # Cancelled by aclose() rather than by the caller
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Interrupt a pending accept() before closing the socket
            self._drop_accept_reader()
            self._accept_scope.cancel()
            await sleep(0)

        self._raw_socket.close()
1433

1434

1435
class UNIXSocketListener(abc.SocketListener):
    """
    UNIX stream socket listener that accepts connections on the raw socket, waiting
    for readability via the event loop whenever ``accept()`` would block.
    """

    # Readiness future a pending accept() is blocked on, kept so that aclose() can
    # wake it up; None while no accept() is waiting
    _accept_future: asyncio.Future | None = None

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        :raises ClosedResourceError: if the listener is closed, including while this
            call is waiting
        :raises BrokenResourceError: if accepting fails for another OS-level reason
        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, f.set_result, None)
                    f.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    self._accept_future = f
                    try:
                        await f
                    finally:
                        self._accept_future = None
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        self._closed = True
        # Closing the socket alone never wakes a waiting accept(): once the fd is
        # closed no readiness event will fire, so it would hang forever. Resolve
        # its future instead; accept() then retries on the closed socket, gets an
        # OSError and raises ClosedResourceError.
        if self._accept_future is not None and not self._accept_future.done():
            self._accept_future.set_result(None)

        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
1470

1471

1472
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        protocol = self._protocol
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Nothing buffered and the transport still open: wait for a datagram
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1518

1519

1520
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """UDP socket connected to a fixed peer, backed by an asyncio transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        protocol = self._protocol
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Nothing buffered and the transport still open: wait for a datagram
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                datagram = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                raise BrokenResourceError from None

            # Queue entries are (payload, address); the peer is fixed, so drop it
            return datagram[0]

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1568

1569

1570
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """
    Unconnected UNIX datagram socket operating on the raw socket, waiting for
    readiness via the event loop whenever an operation would block.
    """

    async def receive(self) -> UNIXDatagramPacketType:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    packet = self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                return packet

    async def send(self, item: UNIXDatagramPacketType) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
1604

1605

1606
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """UNIX datagram socket with a fixed peer, using plain ``recv``/``send``."""

    async def receive(self) -> bytes:
        """Receive one datagram (at most 65536 bytes) from the connected peer.

        :raises ClosedResourceError: if ``self._closing`` is set when the error occurs
        :raises BrokenResourceError: for any other ``OSError``
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Nothing to read yet: wait for readability and retry
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """Send *item* as one datagram to the connected peer.

        :raises ClosedResourceError: if ``self._closing`` is set when the error occurs
        :raises BrokenResourceError: for any other ``OSError``
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # Would block: wait for writability and retry
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
1640

1641

1642
# Registries mapping an opaque key to the asyncio.Event used to wait for readability
# (_read_events) / writability (_write_events). Held in RunVars; presumably scoped to
# the running event loop per anyio's RunVar semantics, and keyed by socket/fd objects
# — TODO confirm against the wait_readable/wait_writable helpers.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
1644

1645

1646
#
1647
# Synchronization
1648
#
1649

1650

1651
class Event(BaseEvent):
    """anyio event implemented on top of a plain :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Plain object allocation; skips BaseEvent.__new__
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the flag, waking every task blocked in :meth:`wait`."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` once :meth:`set` has been called."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Block until the event is set.

        Even when the flag is already set, the call passes through a checkpoint so
        that ``wait()`` always yields to the event loop.
        """
        if not self.is_set():
            await self._event.wait()
        else:
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Report how many tasks are waiting (via asyncio.Event's private ``_waiters``)."""
        waiting = len(self._event._waiters)
        return EventStatistics(waiting)
1672

1673

1674
class Lock(BaseLock):
    """asyncio implementation of anyio's task-owned lock.

    ``_owner_task`` is the task currently holding the lock (``None`` when free) and
    ``_waiters`` is a FIFO of ``(task, future)`` pairs for tasks blocked in
    :meth:`acquire`. :meth:`release` hands ownership straight to the first waiter
    whose future is not cancelled; that waiter removes its own entry from
    ``_waiters`` once it resumes.
    """

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        # Plain object allocation; skips BaseLock.__new__
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        # When True, an uncontended acquire() does not yield to the event loop
        self._fast_acquire = fast_acquire
        self._owner_task: asyncio.Task | None = None
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """Acquire the lock, waiting in FIFO order if it is held or contended."""
        if self._owner_task is None and not self._waiters:
            # Uncontended path: honour a pending cancellation first, then take
            # ownership without awaiting anything in between
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = current_task()

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    # Don't keep the lock if we are being cancelled
                    self.release()
                    raise

            return

        # Contended path: queue up and wait for release() to hand us the lock
        task = cast(asyncio.Task, current_task())
        fut: asyncio.Future[None] = asyncio.Future()
        item = task, fut
        self._waiters.append(item)
        try:
            await fut
        except CancelledError:
            self._waiters.remove(item)
            # release() may already have made us the owner before the cancellation
            # was delivered; pass the lock on instead of leaking it
            if self._owner_task is task:
                self.release()

            raise

        self._waiters.remove(item)

    def acquire_nowait(self) -> None:
        """Acquire the lock without waiting.

        :raises WouldBlock: if the lock is held or other tasks are queued for it
        """
        if self._owner_task is None and not self._waiters:
            self._owner_task = current_task()
            return

        raise WouldBlock

    def locked(self) -> bool:
        """Return ``True`` if some task currently owns the lock."""
        return self._owner_task is not None

    def release(self) -> None:
        """Release the lock, transferring it to the next live waiter if any.

        :raises RuntimeError: if the calling task does not own the lock
        """
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        # Direct handoff: the first non-cancelled waiter becomes the owner before it
        # even resumes
        for task, fut in self._waiters:
            if not fut.cancelled():
                self._owner_task = task
                fut.set_result(None)
                return

        self._owner_task = None

    def statistics(self) -> LockStatistics:
        """Return (locked, owner task info or None, number of queued waiters)."""
        task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
        return LockStatistics(self.locked(), task_info, len(self._waiters))
1739

1740

1741
class Semaphore(BaseSemaphore):
    """asyncio implementation of anyio's counting semaphore.

    ``_value`` is the number of free slots and ``_waiters`` a FIFO of futures for
    tasks blocked in :meth:`acquire`. :meth:`release` resolves the first
    non-cancelled waiter's future (removing it from the queue) instead of
    incrementing ``_value``.
    """

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        # Plain object allocation; skips BaseSemaphore.__new__
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        # When True, an uncontended acquire() does not yield to the event loop
        self._fast_acquire = fast_acquire
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Take one slot, waiting in FIFO order if none is free or others queue."""
        if self._value > 0 and not self._waiters:
            # Uncontended path: honour a pending cancellation, then take the slot
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    # Give the slot back if we are being cancelled
                    self.release()
                    raise

            return

        # Contended path: release() resolves this future and removes it from the queue.
        # NOTE(review): if _waiters only holds futures of tasks that were cancelled
        # but have not resumed yet, a caller arriving while _value > 0 still queues
        # here and waits for the next release() — confirm this is intended.
        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # Already removed => release() handed us a slot before the
                # cancellation landed; give it back so it is not lost
                self.release()

            raise

    def acquire_nowait(self) -> None:
        """Take one slot without waiting.

        :raises WouldBlock: if no slot is free
        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """Return a slot, handing it directly to the first live waiter if any.

        :raises ValueError: if ``max_value`` is set and already reached
        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                # Safe to mutate the deque here because we return immediately
                self._waiters.remove(fut)
                return

        self._value += 1

    @property
    def value(self) -> int:
        """Number of currently free slots."""
        return self._value

    @property
    def max_value(self) -> int | None:
        """Upper bound on the value, or ``None`` if unbounded."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """Return the number of queued waiter futures."""
        return SemaphoreStatistics(len(self._waiters))
1820

1821

1822
class CapacityLimiter(BaseCapacityLimiter):
    """asyncio implementation of anyio's capacity limiter.

    ``_borrowers`` is the set of objects currently holding a token, and
    ``_wait_queue`` maps each waiting borrower to the :class:`asyncio.Event` that is
    set when that borrower is handed a token (FIFO order). A notified borrower only
    adds itself to ``_borrowers`` after it resumes.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Plain object allocation; skips BaseCapacityLimiter.__new__
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        # Goes through the validating property setter below
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """Total number of tokens (an int, or ``math.inf``)."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """Set the token count, waking as many waiters as tokens were added.

        :raises TypeError: if *value* is neither an int nor infinite
        :raises ValueError: if *value* is below 1
        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """Number of tokens currently held."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """Number of tokens not currently held."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Take a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """Take a token for *borrower* without waiting.

        :raises RuntimeError: if *borrower* already holds a token
        :raises WouldBlock: if tasks are queued or no token is free
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Take a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Take a token for *borrower*, waiting in FIFO order if necessary."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A releaser already popped us from the queue and handed us a
                    # token, but we were cancelled before claiming it. Pass the
                    # token on, otherwise the remaining waiters are never woken.
                    self._notify_next_waiter()

                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def release(self) -> None:
        """Return the current task's token."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """Return *borrower*'s token and wake the next waiter if capacity allows.

        :raises RuntimeError: if *borrower* holds no token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def _notify_next_waiter(self) -> None:
        """Wake the longest-waiting borrower if this limiter has free capacity."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return (borrowed, total, borrowers tuple, number of waiting tasks)."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1934

1935

1936
# Default CapacityLimiter for worker threads, stored in a RunVar. Presumably created
# lazily by current_default_thread_limiter() (not visible here) — TODO confirm.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
1937

1938

1939
#
1940
# Operating system signals
1941
#
1942

1943

1944
class _SignalReceiver:
    """Context manager and async iterator yielding OS signals caught on the loop.

    Entering installs a loop signal handler for each requested signal; every
    delivered signal is queued and yielded by ``async for``.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        # Loop signal-handler callback: enqueue, then wake a pending __anext__()
        self._signal_queue.append(signum)
        if self._future.done():
            return
        self._future.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Deduplicate so each signal gets exactly one handler
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        """Return the next queued signal, waiting for one if the queue is empty."""
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # A fresh future per wait; _deliver() resolves it
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
1984

1985

1986
#
1987
# Testing and debugging
1988
#
1989

1990

1991
class AsyncIOTaskInfo(TaskInfo):
    """TaskInfo for an asyncio task; keeps only a weak reference to the task."""

    def __init__(self, task: asyncio.Task):
        state = _task_states.get(task)
        parent_id = state.parent_id if state is not None else None
        super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        """Return ``True`` if the task is about to be cancelled.

        Checks asyncio's own cancellation state first (``Task.cancelling()`` on
        3.11+, otherwise whether the awaited future was cancelled). It then checks
        the task's anyio cancel scope: the scope itself was cancelled, or the scope
        is unshielded and a parent scope was cancelled.
        """
        task = self._task()
        if not task:
            # The weak reference is dead, so no cancellation can be pending
            return False

        if sys.version_info >= (3, 11):
            if task.cancelling():
                return True
        elif (
            isinstance(task._fut_waiter, asyncio.Future)
            and task._fut_waiter.cancelled()
        ):
            return True

        task_state = _task_states.get(task)
        scope = task_state.cancel_scope if task_state else None
        if not scope:
            return False

        return scope.cancel_called or (not scope.shield and scope._parent_cancelled())
2023

2024

2025
class TestRunner(abc.TestRunner):
    """Runs async tests and fixtures on a single asyncio loop.

    Every test/fixture coroutine is sent over a memory object stream to one
    long-lived "runner task", which awaits them one by one. Exceptions reported to
    the loop's exception handler are collected and re-raised after each call.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop_factory always wins over use_uvloop
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        loop = self.get_loop()
        loop.set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the runner's event loop."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        # Collect ordinary exceptions for re-raising; defer everything else to the
        # loop's default handler
        reported = context.get("exception")
        if isinstance(reported, Exception):
            self._exceptions.append(reported)
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        # Re-raise any exceptions raised in asynchronous callbacks
        if not self._exceptions:
            return

        exceptions, self._exceptions = self._exceptions, []
        if len(exceptions) == 1:
            raise exceptions[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", exceptions
        )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        # Body of the runner task: await each submitted coroutine and report its
        # outcome through the paired future (unless the caller gave up on it)
        with receive_stream, self._send_stream:
            async for awaitable, outcome in receive_stream:
                try:
                    retval = await awaitable
                except BaseException as exc:
                    if not outcome.cancelled():
                        outcome.set_exception(exc)
                else:
                    if not outcome.cancelled():
                        outcome.set_result(retval)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        # Lazily start the runner task on first use
        if not self._runner_task:
            send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._send_stream = send_stream
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """Drive an async-generator fixture: setup, yield its value, then teardown.

        :raises RuntimeError: if the generator yields more than once
        """
        asyncgen = fixture_func(**kwargs)

        def advance() -> Any:
            # Resume the generator inside the runner task
            return self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )

        fixturevalue: T_Retval = advance()
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            advance()
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a coroutine test; its own ``Exception`` joins the callback errors."""
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
2160

2161

2162
class AsyncIOBackend(AsyncBackend):
11✔
2163
    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """Run ``func(*args)`` to completion in a new :class:`Runner`.

        Recognized *options*: ``debug``, ``loop_factory`` and ``use_uvloop`` (the
        last only applies when no ``loop_factory`` is given). *kwargs* is not
        forwarded to *func*.
        """
        loop_factory = options.get("loop_factory")
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        @wraps(func)
        async def wrapper() -> T_Retval:
            # Name the root task after func and register it in _task_states
            # (presumably TaskState(parent_id, cancel_scope) — no parent, no scope)
            root = cast(asyncio.Task, current_task())
            root.set_name(get_callable_name(func))
            _task_states[root] = TaskState(None, None)
            try:
                return await func(*args)
            finally:
                del _task_states[root]

        with Runner(debug=options.get("debug"), loop_factory=loop_factory) as runner:
            return runner.run(wrapper())
2191

2192
    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, which serves as the cross-thread token."""
        loop = get_running_loop()
        return loop
2195

2196
    @classmethod
    def current_time(cls) -> float:
        """Return the running loop's monotonic clock reading."""
        loop = get_running_loop()
        return loop.time()
2199

2200
    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception type asyncio uses to signal cancellation."""
        return asyncio.CancelledError
2203

2204
    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once so other tasks get a chance to run."""
        # ``sleep`` is the module-level ``asyncio.sleep`` import
        await sleep(0)
2207

2208
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield to the loop only if an effective enclosing cancel scope was cancelled.

        The scope chain is walked from the task's innermost scope outwards, stopping
        at the first shielded scope. Tasks unknown to ``_task_states`` (or no task
        at all) return immediately.
        """
        task = current_task()
        if task is None:
            return

        try:
            scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while scope:
            if scope.cancel_called:
                # Suspend so the pending cancellation can be delivered
                await sleep(0)
                continue
            if scope.shield:
                break
            scope = scope._parent_scope
2226

2227
    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the loop once inside a shielded scope (immune to cancellation)."""
        shielded_scope = CancelScope(shield=True)
        with shielded_scope:
            await sleep(0)
2231

2232
    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Suspend the current task for *delay* seconds."""
        # ``sleep`` here resolves to the module-level ``asyncio.sleep`` import, not
        # this classmethod
        await sleep(delay)
2235

2236
    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Return a new asyncio :class:`CancelScope` (no deadline by default)."""
        scope = CancelScope(shield=shield, deadline=deadline)
        return scope
2241

2242
    @classmethod
    def current_effective_deadline(cls) -> float:
        """Return the earliest deadline among the current task's effective scopes.

        Returns ``-math.inf`` if one of those scopes was already cancelled, and
        ``math.inf`` if the task has no registered state. The walk stops at the
        first shielded scope.
        """
        try:
            scope = _task_states[current_task()].cancel_scope  # type: ignore[index]
        except KeyError:
            return math.inf

        deadline = math.inf
        while scope:
            deadline = min(deadline, scope.deadline)
            if scope._cancel_called:
                return -math.inf
            if scope.shield:
                break
            scope = scope._parent_scope

        return deadline
2263

2264
    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Return a new asyncio-backed task group."""
        group = TaskGroup()
        return group
2267

2268
    @classmethod
    def create_event(cls) -> abc.Event:
        """Return a new asyncio-backed event."""
        event = Event()
        return event
2271

2272
    @classmethod
    def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
        """Return a new asyncio-backed lock."""
        lock = Lock(fast_acquire=fast_acquire)
        return lock
2275

2276
    @classmethod
    def create_semaphore(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> abc.Semaphore:
        """Return a new asyncio-backed semaphore."""
        semaphore = Semaphore(
            initial_value, fast_acquire=fast_acquire, max_value=max_value
        )
        return semaphore
2285

2286
    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Return a new asyncio-backed capacity limiter."""
        limiter = CapacityLimiter(total_tokens)
        return limiter
2289

2290
    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """Run ``func(*args)`` in a pooled worker thread and await its result.

        :param abandon_on_cancel: if false (the default), the wait for the result is
            shielded from cancellation; if true, cancellation is not shielded
        :param limiter: capacity limiter to hold while the call runs; defaults to
            ``cls.current_default_thread_limiter()``
        :return: the value returned by *func*, delivered through an asyncio future
        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    # No idle worker: spawn one, stopped when the root task finishes
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    # Reuse the most recently idled worker (right end of the deque);
                    # the longest-idle workers sit at the left end
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # Run func in a copied context with sniffio's "current async library"
                # cleared, since the worker thread is not running an event loop
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                # Scope handed to the worker. When the wait is shielded, pass the
                # enclosing scope instead — presumably so check_cancelled() in the
                # thread still sees outside cancellation (TODO confirm in WorkerThread)
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future
2348

2349
    @classmethod
    def check_cancelled(cls) -> None:
        """Raise ``CancelledError`` if the calling thread's effective scope is cancelled.

        Walks outwards from ``threadlocals.current_cancel_scope``, stopping at the
        first shielded scope.
        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None and not scope.cancel_called:
            if scope.shield:
                return
            scope = scope._parent_scope

        if scope is not None:
            raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
2360

2361
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """From a worker thread, run ``func(*args)`` on the loop *token* and block.

        The new task joins the thread's current cancel scope; an asyncio
        cancellation surfaces as :class:`concurrent.futures.CancelledError`.
        """

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            this_task = cast(asyncio.Task, current_task())
            _task_states[this_task] = TaskState(None, scope)
            scope._tasks.add(this_task)
            try:
                return await func(*args)
            except CancelledError as exc:
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(this_task)

        event_loop = cast(AbstractEventLoop, token)
        ctx = copy_context()
        ctx.run(sniffio.current_async_library_cvar.set, "asyncio")
        coro = task_wrapper(threadlocals.current_cancel_scope)
        result_future: concurrent.futures.Future[T_Retval] = ctx.run(
            asyncio.run_coroutine_threadsafe, coro, event_loop
        )
        return result_future.result()
2388

2389
    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """Schedule ``func(*args)`` on the loop *token* and block for its result.

        Any exception from *func* is forwarded to the caller. Non-``Exception``
        ``BaseException``\\s are also re-raised inside the loop callback.
        """
        outcome: concurrent.futures.Future[T_Retval] = Future()

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                outcome.set_result(func(*args))
            except BaseException as exc:
                outcome.set_exception(exc)
                if isinstance(exc, Exception):
                    return
                raise

        cast(AbstractEventLoop, token).call_soon_threadsafe(wrapper)
        return outcome.result()
2410

2411
    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Return a new blocking portal bound to this backend."""
        portal = BlockingPortal()
        return portal
2414

2415
    @classmethod
    async def open_process(
        cls,
        command: StrOrBytesPath | Sequence[StrOrBytesPath],
        *,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        **kwargs: Any,
    ) -> Process:
        """Spawn a subprocess and wrap it in an anyio :class:`Process`.

        A str/bytes (or path-like) *command* is run through the shell; any other
        sequence is executed directly as an argument vector.
        """
        await cls.checkpoint()
        if isinstance(command, PathLike):
            command = os.fspath(command)

        if isinstance(command, (str, bytes)):
            spawn = asyncio.create_subprocess_shell
            argv: tuple[Any, ...] = (command,)
        else:
            spawn = asyncio.create_subprocess_exec
            argv = tuple(command)

        process = await spawn(
            *argv,
            stdin=stdin,
            stdout=stdout,
            stderr=stderr,
            **kwargs,
        )

        # Only wrap the pipes that were actually created
        stdin_stream: StreamWriterWrapper | None = None
        stdout_stream: StreamReaderWrapper | None = None
        stderr_stream: StreamReaderWrapper | None = None
        if process.stdin:
            stdin_stream = StreamWriterWrapper(process.stdin)
        if process.stdout:
            stdout_stream = StreamReaderWrapper(process.stdout)
        if process.stderr:
            stderr_stream = StreamReaderWrapper(process.stderr)

        return Process(process, stdin_stream, stdout_stream, stderr_stream)
2450

2451
    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Register shutdown handling for the given set of worker processes.

        This does two things:

        * Starts a background task running ``_shutdown_process_pool_on_exit``.
        * Attaches ``_forcibly_shutdown_process_pool_on_exit`` as a done callback
          of the root task.

        :param workers: the worker processes to be shut down

        """
        shutdown_coro = _shutdown_process_pool_on_exit(workers)
        # NOTE(review): the returned task is not stored anywhere; the asyncio docs
        # recommend keeping a strong reference to tasks created this way
        create_task(shutdown_coro, name="AnyIO process pool shutdown task")

        forcible_shutdown = partial(_forcibly_shutdown_process_pool_on_exit, workers)
        root_task = find_root_task()
        root_task.add_done_callback(forcible_shutdown)  # type:ignore[arg-type]
2460

2461
    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Open a TCP connection to ``host``:``port``.

        :param host: the remote host name or address
        :param port: the remote port
        :param local_address: optional local address to bind to, passed to
            ``create_connection()`` as ``local_addr``
        :return: a socket stream wrapping the transport and protocol

        """
        loop = get_running_loop()
        connection = await loop.create_connection(
            StreamProtocol, host, port, local_addr=local_address
        )
        transport, protocol = cast(Tuple[asyncio.Transport, StreamProtocol], connection)
        # Start with reading paused; presumably SocketStream resumes it when data
        # is actually requested
        transport.pause_reading()
        return SocketStream(transport, protocol)
2473

2474
    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect to a UNIX domain socket.

        A non-blocking ``AF_UNIX`` socket is created. Whenever ``connect()``
        raises :exc:`BlockingIOError`, this waits for the socket to become
        writable and then retries.

        The socket is closed if anything goes wrong, including cancellation of
        the wait.

        :param path: the socket address, passed as-is to ``socket.connect()``
        :return: a UNIX socket stream wrapping the connected socket
        :raises OSError: if ``connect()`` fails with anything other than
            :exc:`BlockingIOError`

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        # The outer try spans the whole retry loop. An exception raised inside an
        # ``except BlockingIOError`` clause (e.g. CancelledError from the wait) is
        # not caught by a sibling ``except`` clause, so without this the socket
        # would leak on cancellation.
        try:
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    f: asyncio.Future = loop.create_future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    try:
                        await f
                    finally:
                        # Unregister synchronously, before the socket can be
                        # closed below, rather than from a deferred done-callback
                        loop.remove_writer(raw_socket)
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
2493

2494
    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a :class:`TCPSocketListener`."""
        listener = TCPSocketListener(sock)
        return listener
2497

2498
    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a :class:`UNIXSocketListener`."""
        listener = UNIXSocketListener(sock)
        return listener
2501

2502
    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP endpoint through ``loop.create_datagram_endpoint()``.

        :param family: the address family of the socket
        :param local_address: optional local address to bind to
        :param remote_address: optional remote address; when given, a connected
            socket is returned
        :param reuse_port: forwarded as ``reuse_port``
        :return: a :class:`ConnectedUDPSocket` if ``remote_address`` is truthy,
            otherwise a :class:`UDPSocket`
        :raises Exception: the exception recorded on the protocol during setup,
            if any (the transport is closed first)

        """
        loop = get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(
            DatagramProtocol,
            family=family,
            local_addr=local_address,
            remote_addr=remote_address,
            reuse_port=reuse_port,
        )
        setup_error = protocol.exception
        if setup_error:
            transport.close()
            raise setup_error

        if remote_address:
            return ConnectedUDPSocket(transport, protocol)

        return UDPSocket(transport, protocol)
2525

2526
    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap a UNIX datagram socket, connecting it first if a remote path is given.

        If ``connect()`` raises :exc:`BlockingIOError`, this waits for the socket
        to become writable and retries. On any failure while connecting
        (including cancellation of that wait), ``raw_socket`` is closed.

        :param raw_socket: the (already created) datagram socket to wrap
        :param remote_path: the address to connect to, or a falsy value to leave
            the socket unconnected
        :return: a :class:`ConnectedUNIXDatagramSocket` if ``remote_path`` is
            truthy, otherwise a :class:`UNIXDatagramSocket`
        :raises OSError: if ``connect()`` fails with anything other than
            :exc:`BlockingIOError`

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if not remote_path:
            return UNIXDatagramSocket(raw_socket)

        # The outer try spans the whole retry loop, so exceptions raised while
        # waiting inside the ``except BlockingIOError`` clause (e.g.
        # CancelledError) also close the socket; a sibling ``except`` clause
        # would not catch them.
        try:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    f: asyncio.Future = loop.create_future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    try:
                        await f
                    finally:
                        # Unregister synchronously, before the socket can be
                        # closed below, rather than from a deferred done-callback
                        loop.remove_writer(raw_socket)
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
2549

2550
    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """
        Resolve ``host``/``port`` using the running loop's ``getaddrinfo()``.

        All arguments are forwarded unchanged; see :func:`socket.getaddrinfo`
        for their meaning.
        """
        loop = get_running_loop()
        return await loop.getaddrinfo(
            host,
            port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )
2572

2573
    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """
        Reverse-resolve ``sockaddr`` using the running loop's ``getnameinfo()``.

        :param sockaddr: the socket address to look up
        :param flags: flags forwarded to :func:`socket.getnameinfo`
        :return: a ``(host, port)`` tuple of strings

        """
        loop = get_running_loop()
        return await loop.getnameinfo(sockaddr, flags)
2578

2579
    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` becomes readable.

        The pending wait is recorded in the ``_read_events`` mapping, which is
        created on first use. The loop's reader callback sets the associated
        :class:`asyncio.Event`.

        :raises BusyResourceError: if a wait for this socket is already recorded
        :raises ClosedResourceError: if the recorded entry was removed by someone
            else before this wait finished

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = asyncio.Event()
        read_events[sock] = event
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # A missing entry means it was removed elsewhere; in that case the
            # reader registration is left alone and the wait counts as failed
            still_registered = read_events.pop(sock, None) is not None
            if still_registered:
                loop.remove_reader(sock)

        if not still_registered:
            raise ClosedResourceError
2605

2606
    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` becomes writable.

        The pending wait is recorded in the ``_write_events`` mapping, which is
        created on first use. The loop's writer callback sets the associated
        :class:`asyncio.Event`.

        :raises BusyResourceError: if a wait for this socket is already recorded
        :raises ClosedResourceError: if the recorded entry was removed by someone
            else before this wait finished

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register the socket object itself rather than sock.fileno(). This
        # matches wait_socket_readable() and the remove_writer(sock) call below.
        # The selector can then still match the registration by identity if the
        # socket's fileno() has become -1 (closed) in the meantime.
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            # A missing entry means it was removed elsewhere; in that case the
            # writer registration is left alone and the wait counts as failed
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError
2632

2633
    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the default thread capacity limiter.

        On first use, a ``CapacityLimiter(40)`` is created and stored in
        ``_default_thread_limiter``. Later calls return that stored instance.
        """
        try:
            return _default_thread_limiter.get()
        except LookupError:
            pass

        fresh_limiter = CapacityLimiter(40)
        _default_thread_limiter.set(fresh_limiter)
        return fresh_limiter
2641

2642
    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """Return a :class:`_SignalReceiver` for the given signals."""
        receiver = _SignalReceiver(signals)
        return receiver
2647

2648
    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return task information for the currently running asyncio task."""
        task = current_task()
        return AsyncIOTaskInfo(task)  # type: ignore[arg-type]
2651

2652
    @classmethod
    def get_running_tasks(cls) -> Sequence[TaskInfo]:
        """Return task information for every task in the loop that is not done."""
        unfinished = [task for task in all_tasks() if not task.done()]
        return [AsyncIOTaskInfo(task) for task in unfinished]
2655

2656
    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every other task is blocked on an unresolved future.

        The check is repeated every 0.1 seconds for as long as at least one
        other task is not blocked.
        """
        await cls.checkpoint()
        this_task = current_task()

        def is_unblocked(task: asyncio.Task[Any]) -> bool:
            # A task that is not waiting on a future, or whose future is
            # already resolved, does not count as blocked
            waiter = task._fut_waiter  # type: ignore[attr-defined]
            return waiter is None or waiter.done()

        while any(
            is_unblocked(task) for task in all_tasks() if task is not this_task
        ):
            await sleep(0.1)
2671

2672
    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Create a :class:`TestRunner`, passing ``options`` as keyword arguments."""
        runner = TestRunner(**options)
        return runner
2675

2676

2677
# Module-level alias exposing this module's backend implementation class
# (presumably the attribute anyio looks up when loading the asyncio backend;
# confirm against the backend loader)
backend_class = AsyncIOBackend
11✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc