deepset-ai / haystack / build 18879776280

28 Oct 2025 03:17PM UTC. Coverage: 92.235% (+0.04%) from 92.194%

Pull Request #9942: feat: Add warm_up() method to ChatGenerators for tool initialization
Merge 9ca93ecfb into f98095376

13494 of 14630 relevant lines covered (92.24%)

0.92 hits per line

Source File: haystack/components/generators/chat/hugging_face_local.py
File coverage: 86.34%
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, suppress
from typing import Any, Callable, Literal, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
from haystack.dataclasses.streaming_chunk import select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
    Tool,
    Toolset,
    ToolsType,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    flatten_tools_or_toolsets,
    serialize_tools_or_toolset,
)
from haystack.tools.utils import warm_up_tools
from haystack.utils import (
    ComponentDevice,
    Secret,
    deserialize_callable,
    deserialize_secrets_inplace,
    serialize_callable,
)

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
    from huggingface_hub import model_info
    from transformers import Pipeline as HfPipeline
    from transformers import StoppingCriteriaList, pipeline
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        AsyncHFTokenStreamingHandler,
        HFTokenStreamingHandler,
        StopWordsCriteria,
        convert_message_to_hf_format,
        deserialize_hf_model_kwargs,
        serialize_hf_model_kwargs,
    )


PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})'
)

def default_tool_parser(text: str) -> Optional[list[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract tool calls.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    name = match.group(1) or match.group(3)
    args_str = match.group(2) or match.group(4)

    try:
        arguments = json.loads(args_str)
        return [ToolCall(tool_name=name, arguments=arguments)]
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None

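# A minimal illustration of default_tool_parser (the input below is a
# hypothetical model output, not taken from this module): DEFAULT_TOOL_PATTERN
# matches bare {"name": ..., "arguments": ...} objects, the <tool_call>-prefixed
# variant, and the nested {"function": {...}} variant.
#
#   text = '<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}'
#   default_tool_parser(text)
#   # -> [ToolCall(tool_name="get_weather", arguments={"city": "Berlin"})]
#
# Text without a parseable tool call yields None rather than an exception.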

@component
class HuggingFaceLocalChatGenerator:
    """
    Generates chat responses using models from Hugging Face that run locally.

    Use this component with chat-based models,
    such as `HuggingFaceH4/zephyr-7b-beta` or `meta-llama/Llama-2-7b-chat-hf`.
    LLMs running locally may need powerful hardware.

    ### Usage example

    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```

    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
        with the interaction between computers and human language. It enables computers to understand, interpret, and
        generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
        analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
        process and derive meaning from human language, improving communication between humans and machines.")],
        _name=None,
        _meta={'finish_reason': 'stop', 'index': 0, 'model':
              'mistralai/Mistral-7B-Instruct-v0.2',
              'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
              ]
    }
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "HuggingFaceH4/zephyr-7b-beta",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None,
        stop_words: Optional[list[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
        tool_parsing_function: Optional[Callable[[str], Optional[list[ToolCall]]]] = None,
        async_executor: Optional[ThreadPoolExecutor] = None,
    ) -> None:
        """
        Initializes the HuggingFaceLocalChatGenerator component.

        :param model: The Hugging Face text generation model name or path,
            for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
            The model must be a chat model supporting the ChatML messaging format.
            If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Supported by encoder-decoder models, like T5.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param chat_template: Specifies an optional Jinja template for formatting chat
            messages. Most high-quality chat models have their own templates, but for models without this
            feature or if you prefer a custom template, use this parameter.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
            The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for
            [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).
        :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :param tool_parsing_function:
            A callable that takes a string and returns a list of ToolCall objects or None.
            If None, the default_tool_parser is used, which extracts tool calls using a predefined pattern.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor is
            initialized and used.
        """
        torch_and_transformers_import.check()

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        self.token = token
        token = token.resolve_value() if token else None

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

        if task not in PIPELINE_SUPPORTED_TASKS:
            raise ValueError(
                f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            )
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.tool_parsing_function = tool_parsing_function or default_tool_parser
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline: Optional[HfPipeline] = None
        self.tools = tools

        self._owns_executor = async_executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )
        self._is_warmed_up = False

    def __del__(self) -> None:
        """
        Cleanup when the instance is being destroyed.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def shutdown(self) -> None:
        """
        Explicitly shut down the executor if we own it.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component and warms up tools if provided.
        """
        if self._is_warmed_up:
            return

        # Initialize the pipeline (existing logic)
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

        # Warm up tools (new logic)
        if self.tools:
            warm_up_tools(self.tools)

        self._is_warmed_up = True

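    # A sketch of the behavior added in this PR: warm_up() loads the local HF
    # pipeline and additionally warms up any tools provided at init time. The
    # Tool below is hypothetical and only illustrates the flow; repeated
    # warm_up() calls are no-ops thanks to the _is_warmed_up guard.
    #
    #   from haystack.tools import Tool
    #
    #   def get_weather(city: str) -> str:
    #       return f"Sunny in {city}"
    #
    #   weather_tool = Tool(
    #       name="get_weather",
    #       description="Get the weather for a city.",
    #       parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    #       function=get_weather,
    #   )
    #   generator = HuggingFaceLocalChatGenerator(tools=[weather_tool])
    #   generator.warm_up()  # initializes the pipeline, then warm_up_tools(...)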

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            token=self.token.to_dict() if self.token else None,
            chat_template=self.chat_template,
            tools=serialize_tools_or_toolset(self.tools),
            tool_parsing_function=serialize_callable(self.tool_parsing_function),
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        torch_and_transformers_import.check()  # leave this, cls method
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        tool_parsing_function = init_params.get("tool_parsing_function")
        if tool_parsing_function:
            init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)

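    # A minimal serialization round-trip sketch (assuming default init
    # parameters): note that to_dict() pops the resolved token from
    # huggingface_pipeline_kwargs so the secret value is never written out.
    #
    #   generator = HuggingFaceLocalChatGenerator()
    #   data = generator.to_dict()
    #   restored = HuggingFaceLocalChatGenerator.from_dict(data)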

    @component.output_types(replies=list[ChatMessage])
    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )
        if streaming_callback:
            # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
            prepared_inputs["generation_kwargs"]["streamer"] = HFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )

        # We know it's not None because we check it in _prepare_inputs
        assert self.pipeline is not None
        # Generate responses
        output = self.pipeline(prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"])

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)

        return {"replies": chat_messages}

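    # A short streaming sketch; print_chunk is an illustrative callback, and
    # streaming cannot be combined with tools (see the ValueError raised in
    # __init__ and _prepare_inputs). Each StreamingChunk carries newly
    # generated text in its content attribute.
    #
    #   def print_chunk(chunk):
    #       print(chunk.content, end="", flush=True)
    #
    #   generator = HuggingFaceLocalChatGenerator(streaming_callback=print_chunk)
    #   generator.warm_up()
    #   generator.run([ChatMessage.from_user("Hi!")])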

    def create_message(  # pylint: disable=too-many-positional-arguments
        self,
        text: str,
        index: int,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        prompt: str,
        generation_kwargs: dict[str, Any],
        parse_tool_calls: bool = False,
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
        :returns: A ChatMessage instance.
        """

        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

        # Determine finish reason based on context
        if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
            finish_reason = "length"
        elif tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = "stop"

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        # If tool calls are detected, don't include the text content since it contains the raw tool call format
        return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

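    # Illustrative note on finish_reason: "length" when the completion reaches
    # max_new_tokens, "tool_calls" when the parsing function finds a tool call,
    # "stop" otherwise. A hypothetical call (tokenizer stands in for the
    # pipeline's tokenizer):
    #
    #   msg = generator.create_message(
    #       text='{"name": "get_weather", "arguments": {"city": "Rome"}}',
    #       index=0,
    #       tokenizer=tokenizer,
    #       prompt="What is the weather in Rome?",
    #       generation_kwargs={"max_new_tokens": 512},
    #       parse_tool_calls=True,
    #   )
    #   msg.meta["finish_reason"]  # -> "tool_calls"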

    @staticmethod
    def _validate_stop_words(stop_words: Optional[list[str]]) -> Optional[list[str]]:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :return: A sanitized list of stop words or None if validation fails.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: {stop_words}",
                stop_words=stop_words,
            )
            return None

        return list(set(stop_words or []))

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, list[ChatMessage]]:
        """
        Asynchronously invokes text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        prepared_inputs = self._prepare_inputs(
            messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
        )

        # Validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            async_handler = AsyncHFTokenStreamingHandler(
                tokenizer=prepared_inputs["tokenizer"],
                stream_handler=streaming_callback,
                stop_words=prepared_inputs["stop_words"],
                component_info=ComponentInfo.from_component(self),
            )
            prepared_inputs["generation_kwargs"]["streamer"] = async_handler

            # Use async context manager for proper resource cleanup
            async with self._manage_queue_processor(async_handler):
                output = await asyncio.get_running_loop().run_in_executor(
                    self.executor,
                    # have to ignore since assert self.pipeline is not None doesn't work
                    lambda: self.pipeline(  # type: ignore[misc]
                        prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
                    ),
                )
        else:
            output = await asyncio.get_running_loop().run_in_executor(
                self.executor,
                # have to ignore since assert self.pipeline is not None doesn't work
                lambda: self.pipeline(  # type: ignore[misc]
                    prepared_inputs["prepared_prompt"], **prepared_inputs["generation_kwargs"]
                ),
            )

        chat_messages = self._convert_hf_output_to_chat_messages(hf_pipeline_output=output, **prepared_inputs)
        return {"replies": chat_messages}

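    # A minimal async usage sketch (assuming no event loop is already running):
    # run_async offloads the blocking pipeline call to self.executor, so other
    # coroutines can make progress while the model generates.
    #
    #   import asyncio
    #
    #   async def main():
    #       generator = HuggingFaceLocalChatGenerator()
    #       generator.warm_up()
    #       result = await generator.run_async([ChatMessage.from_user("Hi!")])
    #       print(result["replies"][0].text)
    #
    #   asyncio.run(main())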

    @asynccontextmanager
    async def _manage_queue_processor(self, async_handler: "AsyncHFTokenStreamingHandler"):
        """Context manager for proper queue processor lifecycle management."""
        queue_processor = asyncio.create_task(async_handler.process_queue())
        try:
            yield queue_processor
        finally:
            # Ensure the queue processor is cleaned up properly
            try:
                await asyncio.wait_for(queue_processor, timeout=0.1)
            except asyncio.TimeoutError:
                queue_processor.cancel()
                with suppress(asyncio.CancelledError):
                    await queue_processor

    def _prepare_inputs(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, Any]:
        """
        Prepares the inputs for the Hugging Face pipeline.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
        :raises RuntimeError: If the generation model has not been loaded.
        :raises ValueError: If both tools and streaming_callback are provided.
        """
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        tools = tools or self.tools
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        flat_tools = flatten_tools_or_toolsets(tools)
        _check_duplicate_tool_names(flat_tools)

        tokenizer = self.pipeline.tokenizer
        # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
        assert tokenizer is not None

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # If streaming_callback is provided, ensure that num_return_sequences is set to 1
        if streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                generation_kwargs["num_return_sequences"] = 1

        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]

        prepared_prompt = tokenizer.apply_chat_template(
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in flat_tools] if flat_tools else None,
        )
        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
        assert isinstance(prepared_prompt, str)

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        return {
            "prepared_prompt": prepared_prompt,
            "tokenizer": tokenizer,
            "generation_kwargs": generation_kwargs,
            "tools": flat_tools,
            "stop_words": stop_words,
        }

    def _convert_hf_output_to_chat_messages(
        self,
        *,
        hf_pipeline_output: list[dict[str, Any]],
        prepared_prompt: str,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        generation_kwargs: dict[str, Any],
        stop_words: Optional[list[str]],
        tools: Optional[Union[list[Tool], Toolset]] = None,
    ) -> list[ChatMessage]:
        """
        Converts the Hugging Face pipeline output into a list of ChatMessages.

        :param hf_pipeline_output: The output from the Hugging Face pipeline.
        :param prepared_prompt: The prompt used for generation.
        :param tokenizer: The tokenizer used for generation.
        :param generation_kwargs: The generation parameters.
        :param stop_words: A list of stop words to remove from the replies.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        """
        replies = [o.get("generated_text", "") for o in hf_pipeline_output]

        # Remove stop words from replies if present
        if stop_words:
            for stop_word in stop_words:
                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(
                text=reply,
                index=r_index,
                tokenizer=tokenizer,
                prompt=prepared_prompt,
                generation_kwargs=generation_kwargs,
                parse_tool_calls=bool(tools),
            )
            for r_index, reply in enumerate(replies)
        ]
        return chat_messages