• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 7231713521

16 Dec 2023 11:50AM UTC coverage: 90.885% (+0.03%) from 90.859%
7231713521

Pull #652

github

web-flow
Merge 14f150eca into 3a4ec4799
Pull Request #652: Used TypeVarTuple and ParamSpec in several places

68 of 68 new or added lines in 10 files covered. (100.0%)

5 existing lines in 3 files now uncovered.

4467 of 4915 relevant lines covered (90.89%)

8.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.63
/src/anyio/from_thread.py
1
from __future__ import annotations
10✔
2

3
import sys
10✔
4
import threading
10✔
5
from collections.abc import Awaitable, Callable, Generator
10✔
6
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
10✔
7
from contextlib import AbstractContextManager, contextmanager
10✔
8
from inspect import isawaitable
10✔
9
from types import TracebackType
10✔
10
from typing import (
10✔
11
    Any,
12
    AsyncContextManager,
13
    ContextManager,
14
    Generic,
15
    Iterable,
16
    TypeVar,
17
    cast,
18
    overload,
19
)
20

21
from ._core import _eventloop
10✔
22
from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
10✔
23
from ._core._synchronization import Event
10✔
24
from ._core._tasks import CancelScope, create_task_group
10✔
25
from .abc import AsyncBackend
10✔
26
from .abc._tasks import TaskStatus
10✔
27

28
if sys.version_info >= (3, 11):
10✔
29
    from typing import TypeVarTuple, Unpack
4✔
30
else:
31
    from typing_extensions import TypeVarTuple, Unpack
6✔
32

33
T_Retval = TypeVar("T_Retval")
10✔
34
T_co = TypeVar("T_co", covariant=True)
10✔
35
PosArgsT = TypeVarTuple("PosArgsT")
10✔
36

37

38
def run(
    func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a coroutine function from a worker thread.

    The async backend and the token are looked up from the calling thread's
    thread-local state and the call is delegated to the backend's
    ``run_async_from_thread()``.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function
    :raises RuntimeError: if the backend or the token is missing from the calling
        thread's thread-local state

    """
    try:
        # Both attributes must be present; a missing one raises AttributeError.
        backend, loop_token = (
            threadlocals.current_async_backend,
            threadlocals.current_token,
        )
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    return backend.run_async_from_thread(func, args, token=loop_token)
58

59

60
def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a function in the event loop thread from a worker thread.

    The async backend and the token are looked up from the calling thread's
    thread-local state and the call is delegated to the backend's
    ``run_sync_from_thread()``.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable
    :raises RuntimeError: if the backend or the token is missing from the calling
        thread's thread-local state

    """
    try:
        # Both attributes must be present; a missing one raises AttributeError.
        backend, loop_token = (
            threadlocals.current_async_backend,
            threadlocals.current_token,
        )
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    return backend.run_sync_from_thread(func, args, token=loop_token)
80

81

82
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an async context manager via a portal.

    ``__enter__()`` starts :meth:`run_async_cm` as a task through the portal and
    blocks until the async ``__aenter__()`` has produced a value (or failed).
    ``__exit__()`` records the exception info, wakes the task up and blocks until
    the async ``__aexit__()`` has finished.
    """

    # Resolved with the value (or exception) of the async __aenter__()
    _enter_future: Future[T_co]
    # Future of the task running run_async_cm()
    _exit_future: Future[bool | None]
    # Set from __exit__() to let run_async_cm() proceed to __aexit__()
    _exit_event: Event
    # Exception info handed to __aexit__(); stays all-None until __exit__() runs
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal):
        self._portal = portal
        self._async_cm = async_cm

    async def run_async_cm(self) -> bool | None:
        """
        Enter the async context manager, wait for the exit signal, then exit it.

        :return: the return value of the async context manager's ``__aexit__()``

        """
        try:
            self._exit_event = Event()
            entered_value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Hand the failure to the thread blocked in __enter__()
            self._enter_future.set_exception(exc)
            raise

        self._enter_future.set_result(entered_value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            # NOTE(review): returning from within ``finally`` discards any
            # exception raised by the wait above -- confirm this is intended.
            return await self._async_cm.__aexit__(*self._exit_exc_info)

    def __enter__(self) -> T_co:
        # The enter future must exist before the task is started, as the task
        # resolves it.
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        # Store the exception info first so that the task sees it once woken up
        self._exit_exc_info = (__exc_type, __exc_value, __traceback)
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
131

132

133
class _BlockingPortalTaskStatus(TaskStatus):
    """Task status object that passes the ``started()`` value on to a future."""

    def __init__(self, future: Future):
        self._future = future

    def started(self, value: object = None) -> None:
        """
        Resolve the wrapped future with the given value.

        :param value: the value to set as the result of the future

        """
        self._future.set_result(value)
139

140

141
class BlockingPortal:
    """An object that lets external threads run code in an asynchronous event loop."""

    def __new__(cls) -> BlockingPortal:
        # Instance creation is delegated to the current async backend
        backend = get_async_backend()
        return backend.create_blocking_portal()

    def __init__(self) -> None:
        # Identifier of the thread running this initializer; reset to ``None``
        # by stop()
        self._event_loop_thread_id: int | None = threading.get_ident()
        self._stop_event = Event()
        self._task_group = create_task_group()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Stop the portal first, then let the task group wind down
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure that the portal is running and the caller is on another thread.

        :raises RuntimeError: if the portal is not running or if this method is called
            from the thread whose identifier was recorded by the portal

        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")

        if self._event_loop_thread_id == threading.get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        self._event_loop_thread_id = None
        self._stop_event.set()
        if not cancel_remaining:
            return

        self._task_group.cancel_scope.cancel()

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Call the given function and resolve the future with the outcome.

        If the call returns an awaitable, it is awaited on inside a cancel scope
        which gets cancelled if the future is cancelled.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param future: a future that is resolved with the return value of the
            callable or the exception raised during its execution, or cancelled if
            the call was cancelled

        """

        def cancel_scope_on_future_cancel(f: Future[T_Retval]) -> None:
            if not f.cancelled():
                return

            # Only relay the cancellation when the portal is still running and
            # this callback is not being run in the recorded thread
            if self._event_loop_thread_id not in (None, threading.get_ident()):
                self.call(scope.cancel)

        try:
            outcome = func(*args, **kwargs)
            if not isawaitable(outcome):
                retval = outcome
            else:
                with CancelScope() as scope:
                    if future.cancelled():
                        scope.cancel()
                    else:
                        future.add_done_callback(cancel_scope_on_future_cancel)

                    retval = await outcome
        except self._cancelled_exc_class:
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # Release the cancel scope referenced by the callback closure
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementors must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on.

        :param func: any callable
        :param args: positional arguments passed to ``func``
        :return: the return value of the callable
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        pending = self.start_task_soon(func, *args)
        # Block the calling thread until the task has finished
        return cast(T_Retval, pending.result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        result_future: Future[T_Retval] = Future()
        self._spawn_task_from_thread(func, args, {}, name, result_future)
        return result_future

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            # Nothing to do if task_status.started() was already called
            if task_status_future.done():
                return

            if future.cancelled():
                task_status_future.cancel()
                return

            failure = future.exception()
            if failure:
                task_status_future.set_exception(failure)
            else:
                task_status_future.set_exception(
                    RuntimeError("Task exited without calling task_status.started()")
                )

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        task_future: Future = Future()
        # If the task ends before signalling readiness, task_done() resolves the
        # status future so the result() call below does not block forever
        task_future.add_done_callback(task_done)
        self._spawn_task_from_thread(
            func, args, {"task_status": task_status}, name, task_future
        )
        return task_future, task_status_future.result()

    def wrap_async_context_manager(
        self, cm: AsyncContextManager[T_co]
    ) -> ContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
396

397

398
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio", backend_options: dict[str, Any] | None = None
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def run_portal() -> None:
        async with BlockingPortal() as portal_:
            # Skip publishing the portal if the future was cancelled meanwhile
            if not portal_future.set_running_or_notify_cancel():
                return

            portal_future.set_result(portal_)
            await portal_.sleep_until_stopped()

    portal_future: Future[BlockingPortal] = Future()
    with ThreadPoolExecutor(1) as pool:
        loop_future = pool.submit(
            _eventloop.run,  # type: ignore[arg-type]
            run_portal,
            backend=backend,
            backend_options=backend_options,
        )
        try:
            # Wait until either the portal is available or the event loop has
            # exited
            wait(
                cast(Iterable[Future], [loop_future, portal_future]),
                return_when=FIRST_COMPLETED,
            )
        except BaseException:
            portal_future.cancel()
            loop_future.cancel()
            raise

        if portal_future.done():
            portal = portal_future.result()
            cancel_remaining = False
            try:
                yield portal
            except BaseException:
                # The body of the ``with`` block failed: cancel the remaining
                # tasks when stopping the portal
                cancel_remaining = True
                raise
            finally:
                try:
                    portal.call(portal.stop, cancel_remaining)
                except RuntimeError:
                    pass

        # Propagate any exception raised by the event loop
        loop_future.result()
455

456

457
def check_cancelled() -> None:
    """
    Check if the cancel scope of the host task running the current worker thread has
    been cancelled.

    If the host task's current cancel scope has indeed been cancelled, the
    backend-specific cancellation exception will be raised.

    :raises RuntimeError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`

    """
    try:
        # The backend is only present in the thread-local state of worker threads
        backend: AsyncBackend = threadlocals.current_async_backend
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    backend.check_cancelled()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc