deepset-ai / haystack, build 18530018322
15 Oct 2025 01:10PM UTC. Coverage: 92.075% (-0.03%) from 92.103%
Pull Request #9880 (draft): Expand tools param to include list[Toolset] (merge 6dad544fe into cfa5d2761)
13279 of 14422 relevant lines covered (92.07%), 0.92 hits per line

Source file: haystack/components/generators/chat/hugging_face_local.py (85.86% of lines covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, suppress
from typing import Any, Callable, Literal, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
from haystack.dataclasses.streaming_chunk import select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
    Tool,
    Toolset,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    flatten_tools_or_toolsets,
    serialize_tools_or_toolset,
)
from haystack.utils import (
    ComponentDevice,
    Secret,
    deserialize_callable,
    deserialize_secrets_inplace,
    serialize_callable,
)

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
    from huggingface_hub import model_info
    from transformers import Pipeline as HfPipeline
    from transformers import StoppingCriteriaList, pipeline
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        AsyncHFTokenStreamingHandler,
        HFTokenStreamingHandler,
        StopWordsCriteria,
        convert_message_to_hf_format,
        deserialize_hf_model_kwargs,
        serialize_hf_model_kwargs,
    )


PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})'
)


def default_tool_parser(text: str) -> Optional[list[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract tool calls.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    name = match.group(1) or match.group(3)
    args_str = match.group(2) or match.group(4)

    try:
        arguments = json.loads(args_str)
        return [ToolCall(tool_name=name, arguments=arguments)]
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None
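
To make the pattern concrete, here is a small sketch of the default parser on hand-written output; the sample strings are illustrative and not captured from a real model run.

```python
from haystack.components.generators.chat.hugging_face_local import default_tool_parser

# JSON with "name" and "arguments" keys, optionally wrapped in a <tool_call>
# tag, is matched by the first branch of DEFAULT_TOOL_PATTERN.
raw = '<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}</tool_call>'
calls = default_tool_parser(raw)
assert calls is not None and calls[0].tool_name == "get_weather"
assert calls[0].arguments == {"city": "Berlin"}

# Plain prose without tool-call JSON yields None.
assert default_tool_parser("The weather in Berlin is sunny.") is None
```
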
@component
class HuggingFaceLocalChatGenerator:
    """
    Generates chat responses using models from Hugging Face that run locally.

    Use this component with chat-based models,
    such as `HuggingFaceH4/zephyr-7b-beta` or `meta-llama/Llama-2-7b-chat-hf`.
    LLMs running locally may need powerful hardware.

    ### Usage example

    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```

    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
        with the interaction between computers and human language. It enables computers to understand, interpret, and
        generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
        analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
        process and derive meaning from human language, improving communication between humans and machines.")],
        _name=None,
        _meta={'finish_reason': 'stop', 'index': 0, 'model':
              'mistralai/Mistral-7B-Instruct-v0.2',
              'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
              ]
    }
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "HuggingFaceH4/zephyr-7b-beta",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None,
        stop_words: Optional[list[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[list[Tool], Toolset, list[Toolset]]] = None,
        tool_parsing_function: Optional[Callable[[str], Optional[list[ToolCall]]]] = None,
        async_executor: Optional[ThreadPoolExecutor] = None,
    ) -> None:
        """
        Initializes the HuggingFaceLocalChatGenerator component.

        :param model: The Hugging Face text generation model name or path,
            for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
            The model must be a chat model supporting the ChatML messaging format.
            If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Supported by encoder-decoder models, like T5.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param chat_template: Specifies an optional Jinja template for formatting chat
            messages. Most high-quality chat models have their own templates, but for models without this
            feature or if you prefer a custom template, use this parameter.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
            The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for
            [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).
        :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools, a Toolset, or a list of Toolset instances for which the model can prepare calls.
        :param tool_parsing_function:
            A callable that takes a string and returns a list of ToolCall objects or None.
            If None, the default_tool_parser is used, which extracts tool calls using a predefined pattern.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor is
            initialized and used.
        """
        torch_and_transformers_import.check()

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(flatten_tools_or_toolsets(tools or []))

        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        self.token = token
        token = token.resolve_value() if token else None

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

        if task not in PIPELINE_SUPPORTED_TASKS:
            raise ValueError(
                f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            )
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.tool_parsing_function = tool_parsing_function or default_tool_parser
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline: Optional[HfPipeline] = None
        self.tools = tools

        self._owns_executor = async_executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )
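
To show how the init parameters fit together, here is a hedged construction sketch; the `get_weather` function and its `Tool` definition are hypothetical stand-ins, not part of this module.

```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.tools import Tool

def get_weather(city: str) -> str:
    # Hypothetical tool implementation used only for this sketch.
    return f"Sunny in {city}"

weather_tool = Tool(
    name="get_weather",
    description="Report the weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = HuggingFaceLocalChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    generation_kwargs={"max_new_tokens": 256, "temperature": 0.7},
    tools=[weather_tool],  # tools and streaming_callback are mutually exclusive
)
generator.warm_up()  # loads the pipeline; required before run()
```
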
    def __del__(self) -> None:
        """
        Cleanup when the instance is being destroyed.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def shutdown(self) -> None:
        """
        Explicitly shut down the executor if we own it.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            token=self.token.to_dict() if self.token else None,
            chat_template=self.chat_template,
            tools=serialize_tools_or_toolset(self.tools),
            tool_parsing_function=serialize_callable(self.tool_parsing_function),
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        torch_and_transformers_import.check()  # leave this, cls method
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        tool_parsing_function = init_params.get("tool_parsing_function")
        if tool_parsing_function:
            init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)
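
Continuing that sketch, a minimal round trip through `to_dict` and `from_dict`; this is roughly what pipeline persistence does with these methods.

```python
data = generator.to_dict()

# The resolved token is stripped from huggingface_pipeline_kwargs on
# serialization; the Secret itself is stored separately under "token".
assert "token" not in data["init_parameters"]["huggingface_pipeline_kwargs"]

restored = HuggingFaceLocalChatGenerator.from_dict(data)
assert restored.generation_kwargs["max_new_tokens"] == 256
```
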
    @component.output_types(replies=list[ChatMessage])
    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[list[Tool], Toolset, list[Toolset]]] = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools, a Toolset, or a list of Toolset instances for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )
        if streaming_callback:
            # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
            prepared_inputs["generation_kwargs"]["streamer"] = HFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )

        # We know it's not None because we check it in _prepare_inputs
        assert self.pipeline is not None
        # Generate responses
        output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)

        return {"replies": chat_messages}
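
Continuing the same sketch, a single `run` call where the model may prepare a tool call; a per-call `tools` argument overrides the one given at init, and the reply text is dropped when tool calls are parsed.

```python
from haystack.dataclasses import ChatMessage

messages = [ChatMessage.from_user("What's the weather in Berlin?")]
result = generator.run(messages, tools=[weather_tool])

reply = result["replies"][0]
if reply.tool_calls:
    # finish_reason is "tool_calls" and the reply carries no text content
    print(reply.tool_calls[0].tool_name, reply.tool_calls[0].arguments)
else:
    print(reply.text, reply.meta["finish_reason"])
```
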
    def create_message(  # pylint: disable=too-many-positional-arguments
        self,
        text: str,
        index: int,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        prompt: str,
        generation_kwargs: dict[str, Any],
        parse_tool_calls: bool = False,
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
        :returns: A ChatMessage instance.
        """

        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

        # Determine finish reason based on context
        if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
            finish_reason = "length"
        elif tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = "stop"

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        # If tool calls are detected, don't include the text content since it contains the raw tool call format
        return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

    @staticmethod
    def _validate_stop_words(stop_words: Optional[list[str]]) -> Optional[list[str]]:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :returns: A sanitized list of stop words, or None if validation fails.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: {stop_words}",
                stop_words=stop_words,
            )
            return None

        return list(set(stop_words or []))
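
Because `tool_parsing_function` is just a callable from `str` to `Optional[list[ToolCall]]`, the default regex can be swapped for a custom parser. The `<fn=...>` tag format below is invented purely for illustration.

```python
import json
import re
from typing import Optional

from haystack.dataclasses import ToolCall

def angle_tag_parser(text: str) -> Optional[list[ToolCall]]:
    # Hypothetical format: <fn=tool_name>{"arg": "value"}</fn>
    match = re.search(r"<fn=(\w+)>(\{.*?\})</fn>", text, re.DOTALL)
    if not match:
        return None
    try:
        return [ToolCall(tool_name=match.group(1), arguments=json.loads(match.group(2)))]
    except json.JSONDecodeError:
        return None

custom_generator = HuggingFaceLocalChatGenerator(tool_parsing_function=angle_tag_parser)
```
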
    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[list[Tool], Toolset, list[Toolset]]] = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Asynchronously invokes text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools, a Toolset, or a list of Toolset instances for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        # Validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            async_handler = AsyncHFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )
            prepared_inputs["generation_kwargs"]["streamer"] = async_handler

            # Use async context manager for proper resource cleanup
            async with self._manage_queue_processor(async_handler):
                output = await asyncio.get_running_loop().run_in_executor(
                    self.executor,
                    # have to ignore since assert self.pipeline is not None doesn't work
                    lambda: self.pipeline(  # type: ignore[misc]
                        prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
                    ),
                )
        else:
            output = await asyncio.get_running_loop().run_in_executor(
                self.executor,
                # have to ignore since assert self.pipeline is not None doesn't work
                lambda: self.pipeline(  # type: ignore[misc]
                    prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
                ),
            )

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
        return {"replies": chat_messages}
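
A sketch of driving `run_async` from an event loop; generation itself runs in the component's executor, and `warm_up()` is still synchronous and must be called first.

```python
import asyncio

from haystack.dataclasses import ChatMessage

async def main() -> None:
    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()  # synchronous; loads the pipeline before any run
    result = await generator.run_async([ChatMessage.from_user("Name three NLP tasks.")])
    print(result["replies"][0].text)
    generator.shutdown()  # release the owned single-thread executor explicitly

asyncio.run(main())
```
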
    @asynccontextmanager
    async def _manage_queue_processor(self, async_handler: "AsyncHFTokenStreamingHandler"):
        """Context manager for proper queue processor lifecycle management."""
        queue_processor = asyncio.create_task(async_handler.process_queue())
        try:
            yield queue_processor
        finally:
            # Ensure the queue processor is cleaned up properly
            try:
                await asyncio.wait_for(queue_processor, timeout=0.1)
            except asyncio.TimeoutError:
                queue_processor.cancel()
                with suppress(asyncio.CancelledError):
                    await queue_processor

    def _prepare_inputs(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[list[Tool], Toolset, list[Toolset]]] = None,
    ) -> dict[str, Any]:
        """
        Prepares the inputs for the Hugging Face pipeline.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools, a Toolset, or a list of Toolset instances for which the model can prepare calls.
        :returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
        :raises RuntimeError: If the generation model has not been loaded.
        :raises ValueError: If both tools and streaming_callback are provided.
        """
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        tools = tools or self.tools
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        flat_tools = flatten_tools_or_toolsets(tools or [])
        _check_duplicate_tool_names(flat_tools)

        tokenizer = self.pipeline.tokenizer
        # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
        assert tokenizer is not None

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # If streaming_callback is provided, ensure that num_return_sequences is set to 1
        if streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                generation_kwargs["num_return_sequences"] = 1

        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]

        prepared_prompt = tokenizer.apply_chat_template(
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in flat_tools] if flat_tools else None,
        )
        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
        assert isinstance(prepared_prompt, str)

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        return {
            "prepared_prompt": prepared_prompt,
            "tokenizer": tokenizer,
            "generation_kwargs": generation_kwargs,
            "tools": flat_tools,
            "stop_words": stop_words,
        }

    def _convert_hf_output_to_chat_messages(
        self,
        *,
        hf_pipeline_output: list[dict[str, Any]],
        prepared_prompt: str,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        generation_kwargs: dict[str, Any],
        stop_words: Optional[list[str]],
        tools: Optional[Union[list[Tool], Toolset]] = None,
    ) -> list[ChatMessage]:
        """
        Converts the Hugging Face pipeline output into a list of ChatMessages.

        :param hf_pipeline_output: The output from the Hugging Face pipeline.
        :param prepared_prompt: The prompt used for generation.
        :param tokenizer: The tokenizer used for generation.
        :param generation_kwargs: The generation parameters.
        :param stop_words: A list of stop words to remove from the replies.
        :param tools: A list of tools or a Toolset for which the model can prepare calls.
            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        """
        replies = [o.get("generated_text", "") for o in hf_pipeline_output]

        # Remove stop words from replies if present
        if stop_words:
            for stop_word in stop_words:
                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(
                text=reply,
                index=r_index,
                tokenizer=tokenizer,
                prompt=prepared_prompt,
                generation_kwargs=generation_kwargs,
                parse_tool_calls=bool(tools),
            )
            for r_index, reply in enumerate(replies)
        ]
        return chat_messages
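
Finally, a streaming sketch: the callback receives `StreamingChunk` objects as tokens arrive. Since tools and streaming are mutually exclusive, no tools are passed here, and streaming forces `num_return_sequences` to 1.

```python
from haystack.dataclasses import ChatMessage, StreamingChunk

def print_chunk(chunk: StreamingChunk) -> None:
    # Print each token as it arrives, without a trailing newline.
    print(chunk.content, end="", flush=True)

streaming_generator = HuggingFaceLocalChatGenerator(
    model="HuggingFaceH4/zephyr-7b-beta",
    streaming_callback=print_chunk,
)
streaming_generator.warm_up()
streaming_generator.run([ChatMessage.from_user("Summarize NLP in one sentence.")])
```
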