deepset-ai / haystack / 13259218501

11 Feb 2025 09:00AM UTC · coverage: 91.459% (-1.3%) from 92.709%
Pull Request #8829: fix: Look through all streaming chunks for tools calls
Merge 427e76339 into ad90e106a (via github web-flow)

9413 of 10292 relevant lines covered (91.46%) · 0.91 hits per line
Source File
haystack/components/generators/chat/hugging_face_local.py (84.46% of lines covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
import re
import sys
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import (
    ComponentDevice,
    Secret,
    deserialize_callable,
    deserialize_secrets_inplace,
    serialize_callable,
)

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
    from huggingface_hub import model_info
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteriaList, pipeline

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        HFTokenStreamingHandler,
        StopWordsCriteria,
        convert_message_to_hf_format,
        deserialize_hf_model_kwargs,
        serialize_hf_model_kwargs,
    )


PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})'
)
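
# Illustrative examples (not from the original source) of the two payload shapes the
# pattern above is meant to capture (re.DOTALL lets a match span multiple lines):
#   '<tool_call>{"name": "weather", "arguments": {"city": "Paris"}}'    -> groups 1 and 2
#   '{"function": {"name": "weather", "arguments": {"city": "Paris"}}}' -> groups 3 and 4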


def default_tool_parser(text: str) -> Optional[List[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract tool calls.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    name = match.group(1) or match.group(3)
    args_str = match.group(2) or match.group(4)

    try:
        arguments = json.loads(args_str)
        return [ToolCall(tool_name=name, arguments=arguments)]
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None

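
# A minimal usage sketch for the default parser (illustrative, not from the original source):
#
#     calls = default_tool_parser('<tool_call>{"name": "weather", "arguments": {"city": "Paris"}}')
#     # -> [ToolCall(tool_name="weather", arguments={"city": "Paris"})]
#     default_tool_parser("plain text with no tool call")  # -> None
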
@component
class HuggingFaceLocalChatGenerator:
    """
    Generates chat responses using models from Hugging Face that run locally.

    Use this component with chat-based models,
    such as `HuggingFaceH4/zephyr-7b-beta` or `meta-llama/Llama-2-7b-chat-hf`.
    LLMs running locally may need powerful hardware.

    ### Usage example

    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```

    ```
    {'replies':
        [ChatMessage(content=' Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
        with the interaction between computers and human language. It enables computers to understand, interpret, and
        generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
        analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
        process and derive meaning from human language, improving communication between humans and machines.',
        role=<ChatRole.ASSISTANT: 'assistant'>,
        name=None,
        meta={'finish_reason': 'stop', 'index': 0, 'model':
              'mistralai/Mistral-7B-Instruct-v0.2',
              'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
              ]
    }
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "HuggingFaceH4/zephyr-7b-beta",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        tools: Optional[List[Tool]] = None,
        tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
    ):
        """
        Initializes the HuggingFaceLocalChatGenerator component.

        :param model: The Hugging Face text generation model name or path,
            for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
            The model must be a chat model supporting the ChatML messaging format.
            If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Supported by encoder-decoder models, like T5.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param chat_template: Specifies an optional Jinja template for formatting chat
            messages. Most high-quality chat models have their own templates, but for models without this
            feature or if you prefer a custom template, use this parameter.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
            The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for
            [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).
        :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools for which the model can prepare calls.
        :param tool_parsing_function:
            A callable that takes a string and returns a list of ToolCall objects or None.
            If None, the default_tool_parser will be used, which extracts tool calls using a predefined pattern.
        """
        torch_and_transformers_import.check()

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        self.token = token
        token = token.resolve_value() if token else None

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

        if task not in PIPELINE_SUPPORTED_TASKS:
            raise ValueError(
                f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            )
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.tool_parsing_function = tool_parsing_function or default_tool_parser
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline = None
        self.tools = tools

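    # A construction sketch (illustrative, not from the original source): stop_words are
    # folded into generation_kwargs["stop_sequences"], and max_new_tokens defaults to 512
    # unless overridden as below.
    #
    #     generator = HuggingFaceLocalChatGenerator(
    #         model="HuggingFaceH4/zephyr-7b-beta",
    #         generation_kwargs={"max_new_tokens": 256, "temperature": 0.7},
    #         stop_words=["</answer>"],
    #     )
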
    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            token=self.token.to_dict() if self.token else None,
            chat_template=self.chat_template,
            tools=serialized_tools,
            tool_parsing_function=serialize_callable(self.tool_parsing_function),
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        torch_and_transformers_import.check()  # leave this, cls method
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        tool_parsing_function = init_params.get("tool_parsing_function")
        if tool_parsing_function:
            init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)

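    # A serialization round-trip sketch (illustrative, not from the original source).
    # to_dict() drops the resolved token from huggingface_pipeline_kwargs and stores the
    # Secret separately, so from_dict() can re-resolve it from the environment:
    #
    #     data = generator.to_dict()
    #     assert "token" not in data["init_parameters"]["huggingface_pipeline_kwargs"]
    #     restored = HuggingFaceLocalChatGenerator.from_dict(data)
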
    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        tools: Optional[List[Tool]] = None,
    ):
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
            provided during initialization.
        :returns:
            A list containing the generated responses as ChatMessage instances.
        """
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        tools = tools or self.tools
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        tokenizer = self.pipeline.tokenizer

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # pipeline call doesn't support stop_sequences, so we need to pop it
        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        streaming_callback = streaming_callback or self.streaming_callback
        if streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                generation_kwargs["num_return_sequences"] = 1
            # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
            generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)

        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]
        prepared_prompt = tokenizer.apply_chat_template(
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in tools] if tools else None,
        )

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        # Generate responses
        output = self.pipeline(prepared_prompt, **generation_kwargs)
        replies = [o.get("generated_text", "") for o in output]

        # Remove stop words from replies if present
        for stop_word in stop_words:
            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(
                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
            )
            for r_index, reply in enumerate(replies)
        ]

        return {"replies": chat_messages}

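    # A tool-calling sketch (illustrative, not from the original source; `get_weather`
    # is an assumed user-defined function):
    #
    #     from haystack.tools import Tool
    #
    #     weather_tool = Tool(
    #         name="weather",
    #         description="Get the current weather for a city.",
    #         parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    #         function=get_weather,
    #     )
    #     generator = HuggingFaceLocalChatGenerator(tools=[weather_tool])
    #     generator.warm_up()
    #     result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])
    #     # result["replies"][0].tool_calls holds any ToolCall parsed from the output
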
    def create_message(  # pylint: disable=too-many-positional-arguments
        self,
        text: str,
        index: int,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        prompt: str,
        generation_kwargs: Dict[str, Any],
        parse_tool_calls: bool = False,
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
        :returns: A ChatMessage instance.
        """

        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

        # Determine finish reason based on context
        if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
            finish_reason = "length"
        elif tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = "stop"

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        # If tool calls are detected, don't include the text content since it contains the raw tool call format
        return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

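    # Illustrative finish_reason outcomes (not from the original source), assuming the
    # default generation_kwargs = {"max_new_tokens": 512}:
    #   completion of 512+ tokens      -> "length"
    #   a ToolCall parsed from text    -> "tool_calls"
    #   anything else                  -> "stop"
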
    def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :return: A sanitized list of stop words or None if validation fails.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: {stop_words}",
                stop_words=stop_words,
            )
            return None

        return list(set(stop_words or []))