• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / ibm-generative-ai / 8852326016

26 Apr 2024 05:57PM UTC coverage: 89.378% (-6.3%) from 95.632%
8852326016

push

github

web-flow
fix(llama-index): avoid batching in embeddings (#353)

2 of 2 new or added lines in 1 file covered. (100.0%)

273 existing lines in 38 files now uncovered.

4308 of 4820 relevant lines covered (89.38%)

3.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.47
/src/genai/_utils/async_executor.py
1
import asyncio
4✔
2
import logging
4✔
3
from asyncio import AbstractEventLoop, Task
4✔
4
from collections.abc import Iterator
4✔
5
from concurrent.futures import CancelledError, Future
4✔
6
from contextlib import suppress
4✔
7
from typing import Awaitable, Callable, Generator, Generic, Optional, TypeVar, Union
4✔
8

9
from genai._utils.general import first_defined, single_execution
4✔
10
from genai._utils.http_client.httpx_client import AsyncHttpxClient
4✔
11
from genai._utils.limiters.base_limiter import BaseLimiter
4✔
12
from genai._utils.limiters.container_limiter import LimiterContainer
4✔
13
from genai._utils.limiters.shared_limiter import LoopBoundLimiter
4✔
14
from genai._utils.queues.flushable_queue import FlushableQueue
4✔
15
from genai._utils.queues.ordered_queue import OrderedQueue
4✔
16

17
# Public API of this module: the entry-point generator and the defaults holder.
__all__ = ["execute_async", "BaseConfig"]

# NOTE: kept below the other imports in the original file; placement preserved.
from genai._utils.shared_loop import shared_event_loop

# Type variables binding the per-item input type and the handler's result type.
TInput = TypeVar("TInput")
TResult = TypeVar("TResult")


logger = logging.getLogger(__name__)
26

27

28
class BaseConfig:
    """Module-wide defaults for async execution.

    These class attributes act as fallback values: ``_AsyncGenerator`` reads them
    (via ``first_defined``) whenever the corresponding constructor argument is
    left as ``None``. ``__slots__ = ()`` keeps instances dict-less; the values
    are intended to be read (or overridden) at class level.
    """

    __slots__ = ()
    # Yield results in input order (True) or as they complete (False).
    ordered: bool = False
    # Re-raise an item's exception from the consumer loop instead of yielding.
    throw_on_error: bool = True
    # Retry threshold used by rate limiters; not referenced in this file — presumably
    # consumed by the limiter implementations. TODO confirm at call sites.
    limit_reach_retry_threshold: float = 1.0
    # Optional cap on concurrent in-flight requests; None means unlimited here.
    concurrency_limit: Optional[int] = None
34

35

36
class _AsyncGenerator(Generic[TInput, TResult]):
    """
    Utility iterator to process 'inputs' asynchronously by spawning new thread with its own event loop.
    Communication is done via Queues.

    Each input is handed to `handler` as an asyncio task on the shared event
    loop; results (or exceptions) are pushed onto a thread-safe queue which the
    consuming thread drains in `create_iterator`.
    """

    def __init__(
        self,
        *,
        inputs: list[TInput],
        http_client: Callable[[], AsyncHttpxClient],
        handler: Callable[[TInput, AsyncHttpxClient, BaseLimiter], Awaitable[TResult]],
        limiters: Optional[list[Optional[LoopBoundLimiter]]] = None,
        ordered: Optional[bool] = None,
        throw_on_error: Optional[bool] = None,
    ):
        """
        Args:
            inputs: Items to process; one handler task is spawned per item.
            http_client: Factory producing the async HTTP client (used as an
                async context manager for the lifetime of all tasks).
            handler: Coroutine invoked per input with the shared client/limiter.
            limiters: Optional rate limiters combined into a LimiterContainer.
            ordered: Yield results in input order; defaults to BaseConfig.ordered.
            throw_on_error: Re-raise per-item errors; defaults to BaseConfig.throw_on_error.
        """
        self._inputs = inputs
        self._http_client_factory = http_client
        self._handler = handler
        self._limiters = limiters
        self._ordered = first_defined(ordered, default=BaseConfig.ordered)
        self._throw_on_error = first_defined(throw_on_error, default=BaseConfig.throw_on_error)

        # BUGFIX: pick the queue from the *resolved* flag (self._ordered), not the
        # raw argument. Previously, ordered=None with BaseConfig.ordered=True
        # silently used a FlushableQueue and yielded results out of order.
        self._queue: Union[OrderedQueue, FlushableQueue] = OrderedQueue() if self._ordered else FlushableQueue()
        self._irrecoverable_error = False
        self._future: Optional[Future] = None

    def _add_to_queue(
        self,
        *,
        idx: Optional[int],
        result: Optional[TResult],
        error: Optional[Exception],
    ):
        # Entries are (index, result, error) triples; exactly one of
        # result/error is meaningful. idx=None marks a batch-level failure.
        entry = idx, result, error
        self._queue.put_nowait(entry)

    async def _process_input(
        self,
        limiter: BaseLimiter,
        batch_num: int,
        input: TInput,
        client: AsyncHttpxClient,
    ):
        """Run the handler for one input under the limiter; enqueue result or error."""
        async with limiter:
            logger.debug(f"Creating task for batch_num: {batch_num}")
            try:
                response = await self._handler(input, client, limiter)
                logger.debug("Received response = {}".format(response))
                self._add_to_queue(idx=batch_num, result=response, error=None)
            except Exception as e:
                # Per-item failures are enqueued, not raised, so the consumer
                # decides (via throw_on_error) whether to propagate.
                logger.error(f"Exception raised during processing\n{str(e)}")
                self._add_to_queue(idx=batch_num, result=None, error=e)

    async def _schedule_requests(self, limiter: BaseLimiter, loop: AbstractEventLoop):
        """Spawn one task per input inside a single HTTP client context.

        A failure here (client setup, gather) is irrecoverable: it is flagged,
        enqueued so the consumer wakes up, and re-raised into the Future.
        """
        tasks: list[Task] = []
        try:
            async with self._http_client_factory() as client:
                tasks.extend(
                    loop.create_task(self._process_input(limiter, idx, input, client))
                    for idx, input in enumerate(self._inputs)
                )
                await asyncio.gather(*tasks)
        except Exception as ex:
            self._irrecoverable_error = True
            self._add_to_queue(idx=None, result=None, error=ex)
            raise ex
        finally:
            # Cancel any stragglers (no-op for tasks that already finished).
            for task in tasks:
                task.cancel()

    @single_execution
    def _handle_close_signal(self):
        """Abort on loop shutdown: cancel work, flush, and enqueue an InterruptedError."""
        if not self._future:
            return
        self._future.cancel()
        self._queue.flush()
        self._irrecoverable_error = True
        self._add_to_queue(
            idx=None,
            result=None,
            error=InterruptedError("Generation has been aborted by the user."),
        )

    def create_iterator(self) -> Iterator[TResult]:
        """Yield one result per input, blocking on the queue from the consumer thread.

        Raises the enqueued error when throw_on_error is set or the failure was
        irrecoverable; always flushes the queue and detaches the close handler.
        """
        if not self._inputs:
            return
        with shared_event_loop as loop:
            limiter = LimiterContainer(*(self._limiters or []))
            self._future = asyncio.run_coroutine_threadsafe(self._schedule_requests(limiter, loop), loop)

            shared_event_loop.add_close_handler(self._handle_close_signal)
            try:
                # Exactly one queue entry is produced per input.
                for _ in range(len(self._inputs)):
                    batch_num, response, error = self._queue.get()
                    self._queue.task_done()

                    if (self._throw_on_error or self._irrecoverable_error) and error:
                        raise error

                    yield response
            except Exception:
                # future is PENDING even if there are running tasks due to implementation of run_coroutine_threadsafe
                # therefore cancel always succeeds
                self._future.cancel()
                raise
            finally:
                with suppress(CancelledError):
                    self._future.result()
                self._queue.flush()
                shared_event_loop.remove_close_handler(self._handle_close_signal)

    def __iter__(self) -> Iterator[TResult]:
        return self.create_iterator()
150

151

152
def execute_async(
    *,
    inputs: list[TInput],
    http_client: Callable[[], AsyncHttpxClient],
    handler: Callable[[TInput, AsyncHttpxClient, BaseLimiter], Awaitable[TResult]],
    limiters: Optional[list[Optional[LoopBoundLimiter]]] = None,
    ordered: Optional[bool] = None,
    throw_on_error: Optional[bool] = None,
) -> Generator[TResult, None, None]:
    """Process `inputs` concurrently and yield one result per input.

    Thin generator facade over `_AsyncGenerator`: all arguments are forwarded
    unchanged, and results are yielded as the background tasks complete (or in
    input order when `ordered` resolves to True).
    """
    generator = _AsyncGenerator(
        inputs=inputs,
        http_client=http_client,
        handler=handler,
        limiters=limiters,
        ordered=ordered,
        throw_on_error=throw_on_error,
    )
    yield from generator
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc