
deepset-ai / haystack · build 20135386210

11 Dec 2025 01:51PM UTC · coverage: 92.133% (remained the same)

Pull Request #10156: chore: Update code snippets in docs (audio and builders components)
Merge 3d42fa2f2 into 63de06bdf

14124 of 15330 relevant lines covered (92.13%) · 0.92 hits per line

Source file: haystack/components/generators/chat/hugging_face_local.py (85.51% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, suppress
from typing import Any, Callable, Literal, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
from haystack.dataclasses.streaming_chunk import select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
    Tool,
    Toolset,
    ToolsType,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    flatten_tools_or_toolsets,
    serialize_tools_or_toolset,
)
from haystack.tools.utils import warm_up_tools
from haystack.utils import (
    ComponentDevice,
    Secret,
    deserialize_callable,
    deserialize_secrets_inplace,
    serialize_callable,
)

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
    from huggingface_hub import model_info
    from transformers import Pipeline as HfPipeline
    from transformers import StoppingCriteriaList, pipeline
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        AsyncHFTokenStreamingHandler,
        HFTokenStreamingHandler,
        StopWordsCriteria,
        convert_message_to_hf_format,
        deserialize_hf_model_kwargs,
        serialize_hf_model_kwargs,
    )


PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})'
)


def default_tool_parser(text: str) -> Optional[list[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract tool calls.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    name = match.group(1) or match.group(3)
    args_str = match.group(2) or match.group(4)

    try:
        arguments = json.loads(args_str)
        return [ToolCall(tool_name=name, arguments=arguments)]
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None

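# A minimal usage sketch (not part of the original module): default_tool_parser
# handles the common `<tool_call>{...}` format emitted by many chat models. The
# tool name and arguments below are illustrative, not from a real model run.
def _example_default_tool_parser_usage() -> None:
    calls = default_tool_parser('<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}')
    assert calls is not None
    assert calls[0].tool_name == "get_weather"
    assert calls[0].arguments == {"city": "Berlin"}
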
@component
class HuggingFaceLocalChatGenerator:
    """
    Generates chat responses using models from Hugging Face that run locally.

    Use this component with chat-based models,
    such as `Qwen/Qwen3-0.6B` or `meta-llama/Llama-2-7b-chat-hf`.
    LLMs running locally may need powerful hardware.

    ### Usage example

    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="Qwen/Qwen3-0.6B")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```

    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
        with the interaction between computers and human language. It enables computers to understand, interpret, and
        generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
        analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
        process and derive meaning from human language, improving communication between humans and machines.")],
        _name=None,
        _meta={'finish_reason': 'stop', 'index': 0, 'model':
              'Qwen/Qwen3-0.6B',
              'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
              ]
    }
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "Qwen/Qwen3-0.6B",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None,
        stop_words: Optional[list[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
        tool_parsing_function: Optional[Callable[[str], Optional[list[ToolCall]]]] = None,
        async_executor: Optional[ThreadPoolExecutor] = None,
        *,
        enable_thinking: bool = False,
    ) -> None:
        """
        Initializes the HuggingFaceLocalChatGenerator component.

        :param model: The Hugging Face text generation model name or path,
            for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
            The model must be a chat model supporting the ChatML messaging format.
            If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Supported by encoder-decoder models, like T5.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param chat_template: Specifies an optional Jinja template for formatting chat
            messages. Most high-quality chat models have their own templates, but for models without this
            feature or if you prefer a custom template, use this parameter.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
            The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).
        :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :param tool_parsing_function:
            A callable that takes a string and returns a list of ToolCall objects or None.
            If None, the default_tool_parser will be used, which extracts tool calls using a predefined pattern.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor is
            initialized and used.
        :param enable_thinking:
            Whether to enable thinking mode in the chat template for thinking-capable models.
            When enabled, the model generates intermediate reasoning before the final response. Defaults to False.
        """
        torch_and_transformers_import.check()

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        self.token = token
        token = token.resolve_value() if token else None

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

        if task not in PIPELINE_SUPPORTED_TASKS:
            raise ValueError(
                f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            )
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.tool_parsing_function = tool_parsing_function or default_tool_parser
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline: Optional[HfPipeline] = None
        self.tools = tools
        self.enable_thinking = enable_thinking

        self._owns_executor = async_executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )
        self._is_warmed_up = False

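    # Illustrative configuration (a sketch, not from the original file): kwargs in
    # `huggingface_pipeline_kwargs` take precedence over the top-level parameters,
    # and `stop_words` is folded into `generation_kwargs["stop_sequences"]`:
    #
    #     generator = HuggingFaceLocalChatGenerator(
    #         model="Qwen/Qwen3-0.6B",
    #         generation_kwargs={"max_new_tokens": 128, "temperature": 0.7},
    #         stop_words=["Observation:"],
    #     )
    #     generator.generation_kwargs["stop_sequences"]  # ["Observation:"]
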
    def __del__(self) -> None:
        """
        Cleanup when the instance is being destroyed.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def shutdown(self) -> None:
        """
        Explicitly shut down the executor if we own it.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component and warms up tools if provided.
        """
        if self._is_warmed_up:
            return

        # Initialize the pipeline
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

        # Warm up tools
        if self.tools:
            warm_up_tools(self.tools)

        self._is_warmed_up = True

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            token=self.token.to_dict() if self.token else None,
            chat_template=self.chat_template,
            tools=serialize_tools_or_toolset(self.tools),
            tool_parsing_function=serialize_callable(self.tool_parsing_function),
            enable_thinking=self.enable_thinking,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        torch_and_transformers_import.check()  # leave this, cls method
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        tool_parsing_function = init_params.get("tool_parsing_function")
        if tool_parsing_function:
            init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)

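    # Illustrative round trip (a sketch): serialization captures the init
    # configuration only; the HF pipeline itself is not serialized and is rebuilt
    # on warm_up(). The resolved `token` value is dropped from the pipeline kwargs.
    #
    #     data = generator.to_dict()
    #     restored = HuggingFaceLocalChatGenerator.from_dict(data)
    #     restored.warm_up()
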
    @component.output_types(replies=list[ChatMessage])
    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        if self.pipeline is None:
            self.warm_up()

        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )
        if streaming_callback:
            # The `streamer` parameter hooks into HF streaming; HFTokenStreamingHandler adapts it to our callback
            prepared_inputs["generation_kwargs"]["streamer"] = HFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )

        # We know it's not None because we check it in _prepare_inputs
        assert self.pipeline is not None
        # Generate responses
        output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)

        return {"replies": chat_messages}

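    # Illustrative streaming usage (a sketch; `print_chunk` is a made-up callback).
    # The callback receives StreamingChunk objects as tokens are generated:
    #
    #     from haystack.dataclasses import StreamingChunk
    #
    #     def print_chunk(chunk: StreamingChunk) -> None:
    #         print(chunk.content, end="", flush=True)
    #
    #     generator.run(messages, streaming_callback=print_chunk)
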
    def create_message(  # pylint: disable=too-many-positional-arguments
        self,
        text: str,
        index: int,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        prompt: str,
        generation_kwargs: dict[str, Any],
        parse_tool_calls: bool = False,
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
        :returns: A ChatMessage instance.
        """

        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

        # Determine finish reason based on context
        if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
            finish_reason = "length"
        elif tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = "stop"

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        # If tool calls are detected, don't include the text content since it contains the raw tool call format
        return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

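    # Illustrative result (a sketch; `tokenizer` and `prompt` are assumed from
    # context, and exact token counts depend on the tokenizer):
    #
    #     msg = generator.create_message(
    #         text="Hello!", index=0, tokenizer=tokenizer, prompt=prompt,
    #         generation_kwargs={"max_new_tokens": 512},
    #     )
    #     msg.meta["finish_reason"]          # "stop" (not truncated, no tool calls)
    #     msg.meta["usage"]["total_tokens"]  # prompt_tokens + completion_tokens
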
    @staticmethod
    def _validate_stop_words(stop_words: Optional[list[str]]) -> Optional[list[str]]:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :return: A sanitized list of stop words or None if validation fails.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: {stop_words}",
                stop_words=stop_words,
            )
            return None

        return list(set(stop_words or []))

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Asynchronously invokes text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        if self.pipeline is None:
            self.warm_up()

        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        # Validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            async_handler = AsyncHFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )
            prepared_inputs["generation_kwargs"]["streamer"] = async_handler

            # Use async context manager for proper resource cleanup
            async with self._manage_queue_processor(async_handler):
                output = await asyncio.get_running_loop().run_in_executor(
                    self.executor,
                    # have to ignore since assert self.pipeline is not None doesn't work
                    lambda: self.pipeline(  # type: ignore[misc]
                        prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
                    ),
                )
        else:
            output = await asyncio.get_running_loop().run_in_executor(
                self.executor,
                # have to ignore since assert self.pipeline is not None doesn't work
                lambda: self.pipeline(  # type: ignore[misc]
                    prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
                ),
            )

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
        return {"replies": chat_messages}

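    # Illustrative async usage (a sketch; assumes `generator` and `messages` exist).
    # Generation runs in self.executor so the event loop is not blocked:
    #
    #     import asyncio
    #
    #     async def main():
    #         result = await generator.run_async(messages)
    #         print(result["replies"][0].text)
    #
    #     asyncio.run(main())
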
    @asynccontextmanager
    async def _manage_queue_processor(self, async_handler: "AsyncHFTokenStreamingHandler"):
        """Context manager for proper queue processor lifecycle management."""
        queue_processor = asyncio.create_task(async_handler.process_queue())
        try:
            yield queue_processor
        finally:
            # Ensure the queue processor is cleaned up properly
            try:
                await asyncio.wait_for(queue_processor, timeout=0.1)
            except asyncio.TimeoutError:
                queue_processor.cancel()
                with suppress(asyncio.CancelledError):
                    await queue_processor

    def _prepare_inputs(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, Any]:
        """
        Prepares the inputs for the Hugging Face pipeline.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
        :raises ValueError: If both tools and streaming_callback are provided.
        """
        tools = tools or self.tools
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        flat_tools = flatten_tools_or_toolsets(tools)
        _check_duplicate_tool_names(flat_tools)

        # mypy doesn't know this is set in warm_up
        tokenizer = self.pipeline.tokenizer  # type: ignore[union-attr]

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # If streaming_callback is provided, ensure that num_return_sequences is set to 1
        if streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                generation_kwargs["num_return_sequences"] = 1

        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = (
            StopWordsCriteria(
                tokenizer,  # type: ignore[arg-type]
                stop_words,
                self.pipeline.device,  # type: ignore[union-attr]
            )
            if stop_words
            else None
        )
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]

        # mypy doesn't know tokenizer is set in warm_up
        prepared_prompt = tokenizer.apply_chat_template(  # type: ignore[union-attr]
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in flat_tools] if flat_tools else None,
            enable_thinking=self.enable_thinking,
        )
        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
        assert isinstance(prepared_prompt, str)

        # Avoid some unnecessary warnings in the generation pipeline call
        # mypy doesn't know tokenizer is set in warm_up
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id  # type: ignore[union-attr]
        )

        return {
            "prepared_prompt": prepared_prompt,
            "tokenizer": tokenizer,
            "generation_kwargs": generation_kwargs,
            "tools": flat_tools,
            "stop_words": stop_words,
        }

    def _convert_hf_output_to_chat_messages(
        self,
        *,
        hf_pipeline_output: list[dict[str, Any]],
        prepared_prompt: str,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        generation_kwargs: dict[str, Any],
        stop_words: Optional[list[str]],
        tools: Optional[Union[list[Tool], Toolset]] = None,
    ) -> list[ChatMessage]:
        """
        Converts the Hugging Face pipeline output into a list of ChatMessages.

        :param hf_pipeline_output: The output from the Hugging Face pipeline.
        :param prepared_prompt: The prompt used for generation.
        :param tokenizer: The tokenizer used for generation.
        :param generation_kwargs: The generation parameters.
        :param stop_words: A list of stop words to remove from the replies.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        :returns: A list of ChatMessage instances.
        """
        replies = [o.get("generated_text", "") for o in hf_pipeline_output]

        # Remove stop words from replies if present
        if stop_words:
            for stop_word in stop_words:
                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(
                text=reply,
                index=r_index,
                tokenizer=tokenizer,
                prompt=prepared_prompt,
                generation_kwargs=generation_kwargs,
                parse_tool_calls=bool(tools),
            )
            for r_index, reply in enumerate(replies)
        ]
        return chat_messages