deepset-ai / haystack / build 15131674881
20 May 2025 07:35AM UTC coverage: 90.156% (-0.3%) from 90.471%
Pull Request #9407: feat: stream `ToolResult` from run_async in Agent
Merge b382eca10 into 6ad23f822
10972 of 12170 relevant lines covered (90.16%), 0.9 hits per line

Source file: haystack/components/generators/hugging_face_local.py (91.76% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Literal, Optional, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import (
    ComponentDevice,
    Secret,
    deserialize_callable,
    deserialize_secrets_inplace,
    serialize_callable,
)
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs

logger = logging.getLogger(__name__)

SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
    from transformers import Pipeline, StoppingCriteriaList, pipeline

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        HFTokenStreamingHandler,
        StopWordsCriteria,
        resolve_hf_pipeline_kwargs,
    )


@component
class HuggingFaceLocalGenerator:
    """
    Generates text using models from Hugging Face that run locally.

    LLMs running locally may need powerful hardware.

    ### Usage example

    ```python
    from haystack.components.generators import HuggingFaceLocalGenerator

    generator = HuggingFaceLocalGenerator(
        model="google/flan-t5-large",
        task="text2text-generation",
        generation_kwargs={"max_new_tokens": 100, "temperature": 0.9})

    generator.warm_up()

    print(generator.run("Who is the best American actor?"))
    # {'replies': ['John Cusack']}
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "google/flan-t5-base",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Creates an instance of a HuggingFaceLocalGenerator.

        :param model: The Hugging Face text generation model name or path.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Supported by encoder-decoder models, like T5.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - [transformers.GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For available kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for model initialization:
            [transformers.PreTrainedModel.from_pretrained](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained)
        :param stop_words: If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        transformers_import.check()

        self.token = token
        generation_kwargs = generation_kwargs or {}

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task=task,
            supported_tasks=SUPPORTED_TASKS,
            device=device,
            token=token,
        )

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        task = huggingface_pipeline_kwargs["task"]
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.stop_words = stop_words
        self.pipeline: Optional[Pipeline] = None
        self.stopping_criteria_list: Optional[StoppingCriteriaList] = None
        self.streaming_callback = streaming_callback

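    # Illustrative construction sketch: `stop_words` and `streaming_callback` can be
    # combined with `generation_kwargs` (the stop word and callback below are assumptions
    # chosen for demonstration, not defaults of this class):
    #
    #   generator = HuggingFaceLocalGenerator(
    #       model="google/flan-t5-base",
    #       task="text2text-generation",
    #       stop_words=["Answer:"],
    #       streaming_callback=lambda chunk: print(chunk.content, end=""),
    #       generation_kwargs={"max_new_tokens": 100},
    #   )
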
    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    @property
    def _warmed_up(self) -> bool:
        if self.stop_words:
            return (self.pipeline is not None) and (self.stopping_criteria_list is not None)
        return self.pipeline is not None

    def warm_up(self):
        """
        Initializes the component.
        """
        if self._warmed_up:
            return

        if self.pipeline is None:
            self.pipeline = cast(Pipeline, pipeline(**self.huggingface_pipeline_kwargs))

        if self.stop_words:
            # text-generation and text2text-generation pipelines always have a non-None tokenizer
            assert self.pipeline.tokenizer is not None

            stop_words_criteria = StopWordsCriteria(
                tokenizer=self.pipeline.tokenizer, stop_words=self.stop_words, device=self.pipeline.device
            )
            self.stopping_criteria_list = StoppingCriteriaList([stop_words_criteria])

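    # How stopping works at generation time: the StoppingCriteriaList built in warm_up()
    # is handed to the pipeline call in run(), so generation halts as soon as a stop word
    # is produced; run() then strips the stop words from the returned replies.
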
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            stop_words=self.stop_words,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)

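    # A hedged serialization round-trip sketch: `default_to_dict` emits a mapping with
    # `type` (the import path of this class) and `init_parameters`; to_dict() above also
    # drops the raw `token` from `huggingface_pipeline_kwargs` so secrets are not leaked.
    #
    #   data = generator.to_dict()
    #   restored = HuggingFaceLocalGenerator.from_dict(data)
    #   assert restored.generation_kwargs == generator.generation_kwargs
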
    @component.output_types(replies=List[str])
    def run(
        self,
        prompt: str,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Run the text generation model on the given prompt.

        :param prompt:
            A string representing the prompt.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation.

        :returns:
            A dictionary containing the generated replies.
            - replies: A list of strings representing the generated replies.
        """
        if not self._warmed_up:
            raise RuntimeError(
                "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
            )

        if not prompt:
            return {"replies": []}

        # merge generation kwargs from init method with those from run method
        updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # check if streaming_callback is passed
        streaming_callback = streaming_callback or self.streaming_callback

        if streaming_callback:
            num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                updated_generation_kwargs["num_return_sequences"] = 1
            # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
            updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
                self.pipeline.tokenizer,  # type: ignore
                streaming_callback,
                self.stop_words,  # type: ignore
            )

        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)  # type: ignore
        replies = [o["generated_text"] for o in output if "generated_text" in o]

        if self.stop_words:
            # the output of the pipeline includes the stop word: strip every stop word from each reply
            for stop_word in self.stop_words:
                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        return {"replies": replies}
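

# A minimal end-to-end usage sketch, assuming `transformers[torch]` is installed and the
# default model can be downloaded; the prompt and max_new_tokens value are illustrative:
if __name__ == "__main__":
    def print_chunk(chunk: StreamingChunk) -> None:
        # stream each generated token to stdout as it arrives
        print(chunk.content, end="", flush=True)

    generator = HuggingFaceLocalGenerator(
        task="text2text-generation",
        streaming_callback=print_chunk,
        generation_kwargs={"max_new_tokens": 64},
    )
    generator.warm_up()  # required before run(); otherwise run() raises RuntimeError
    result = generator.run("What is the capital of France?")
    print(result["replies"])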