
deepset-ai / haystack / 15191527043

22 May 2025 04:08PM UTC coverage: 90.345% (-0.07%) from 90.411%

Pull Request #9426: feat: add component name and type to `StreamingChunk`
Merge 212e60881 into 4a5e4d3e6 (github / web-flow)

11173 of 12367 relevant lines covered (90.35%)

0.9 hits per line

Source File

haystack/utils/hf.py (85.71% line coverage)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import copy
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union

from haystack import logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import:
    import torch

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import HfApi, model_info
    from huggingface_hub.utils import RepositoryNotFoundError

logger = logging.getLogger(__name__)


class HFGenerationAPIType(Enum):
    """
    API type to use for Hugging Face API Generators.
    """

    # HF [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference).
    TEXT_GENERATION_INFERENCE = "text_generation_inference"

    # HF [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # HF [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self):
        return self.value

    @staticmethod
    def from_str(string: str) -> "HFGenerationAPIType":
        """
        Convert a string to a HFGenerationAPIType enum.

        :param string: The string to convert.
        :return: The corresponding HFGenerationAPIType enum.
        """
        enum_map = {e.value: e for e in HFGenerationAPIType}
        mode = enum_map.get(string)
        if mode is None:
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {list(enum_map.keys())}"
            raise ValueError(msg)
        return mode


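# A minimal usage sketch for the enum above. The helper name and the chosen string value
# are illustrative assumptions, not part of this module's API: a user-supplied configuration
# string is mapped to the enum, and str() maps it back to the serialized value.
def _example_generation_api_type_round_trip() -> None:
    api_type = HFGenerationAPIType.from_str("serverless_inference_api")
    assert api_type is HFGenerationAPIType.SERVERLESS_INFERENCE_API
    assert str(api_type) == "serverless_inference_api"

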
class HFEmbeddingAPIType(Enum):
    """
    API type to use for Hugging Face API Embedders.
    """

    # HF [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference).
    TEXT_EMBEDDINGS_INFERENCE = "text_embeddings_inference"

    # HF [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # HF [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self):
        return self.value

    @staticmethod
    def from_str(string: str) -> "HFEmbeddingAPIType":
        """
        Convert a string to a HFEmbeddingAPIType enum.

        :param string: The string to convert.
        :return: The corresponding HFEmbeddingAPIType enum.
        """
        enum_map = {e.value: e for e in HFEmbeddingAPIType}
        mode = enum_map.get(string)
        if mode is None:
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {list(enum_map.keys())}"
            raise ValueError(msg)
        return mode


class HFModelType(Enum):
    EMBEDDING = 1
    GENERATION = 2


def serialize_hf_model_kwargs(kwargs: Dict[str, Any]):
    """
    Recursively serialize HuggingFace specific model keyword arguments in-place to make them JSON serializable.

    :param kwargs: The keyword arguments to serialize
    """
    torch_import.check()

    for k, v in kwargs.items():
        # torch.dtype
        if isinstance(v, torch.dtype):
            kwargs[k] = str(v)

        if isinstance(v, dict):
            serialize_hf_model_kwargs(v)


def deserialize_hf_model_kwargs(kwargs: Dict[str, Any]):
    """
    Recursively deserialize HuggingFace specific model keyword arguments in-place, converting serialized values
    (such as `torch.dtype` strings) back to their original types.

    :param kwargs: The keyword arguments to deserialize
    """
    torch_import.check()

    for k, v in kwargs.items():
        # torch.dtype
        if isinstance(v, str) and v.startswith("torch."):
            dtype_str = v.split(".")[1]
            dtype = getattr(torch, dtype_str, None)
            if dtype is not None and isinstance(dtype, torch.dtype):
                kwargs[k] = dtype

        if isinstance(v, dict):
            deserialize_hf_model_kwargs(v)


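# A minimal round-trip sketch for the two helpers above. The helper name, the nested
# "model_kwargs" structure, and the dtype values are illustrative assumptions: serialization
# turns torch dtypes into strings so the dict can be written to JSON, and deserialization
# restores the original dtype objects.
def _example_model_kwargs_round_trip() -> None:
    torch_import.check()

    kwargs = {"torch_dtype": torch.float16, "model_kwargs": {"torch_dtype": torch.bfloat16}}
    serialize_hf_model_kwargs(kwargs)
    # kwargs == {"torch_dtype": "torch.float16", "model_kwargs": {"torch_dtype": "torch.bfloat16"}}
    deserialize_hf_model_kwargs(kwargs)
    assert kwargs["torch_dtype"] is torch.float16
    assert kwargs["model_kwargs"]["torch_dtype"] is torch.bfloat16

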
def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Update `model_kwargs` to include the keyword argument `device_map`.

    This method is useful if you want to force a transformers model loaded with `AutoModel.from_pretrained` to
    use `device_map`.

    We handle the edge case where both `device` and `device_map` are specified by ignoring the `device` parameter
    and logging a warning.

    :param device: The device on which the model is loaded. If `None`, the default device is automatically
        selected.
    :param model_kwargs: Additional HF keyword arguments passed to `AutoModel.from_pretrained`.
        For details on what kwargs you can pass, see the model's documentation.
    """
    model_kwargs = copy.copy(model_kwargs) or {}
    if model_kwargs.get("device_map"):
        if device is not None:
            logger.warning(
                "The parameters `device` and `device_map` from `model_kwargs` are both provided. "
                "Ignoring `device` and using `device_map`."
            )
        # Resolve device if device_map is provided in model_kwargs
        device_map = model_kwargs["device_map"]
    else:
        device_map = ComponentDevice.resolve_device(device).to_hf()

    # Set up device_map which allows quantized loading and multi device inference
    # requires accelerate which is always installed when using `pip install transformers[torch]`
    model_kwargs["device_map"] = device_map

    return model_kwargs


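# A small sketch of the precedence rule implemented above. The helper name and the "auto"
# device map value are illustrative assumptions: an existing "device_map" in model_kwargs
# wins over the explicit device argument; otherwise the resolved device is converted to its
# HF representation.
def _example_resolve_hf_device_map() -> None:
    resolved = resolve_hf_device_map(device=None, model_kwargs={"device_map": "auto"})
    assert resolved["device_map"] == "auto"

    resolved = resolve_hf_device_map(device=ComponentDevice.from_str("cpu"), model_kwargs=None)
    # device_map now holds the HF representation of the resolved device, e.g. "cpu"
    assert "device_map" in resolved

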
def resolve_hf_pipeline_kwargs(  # pylint: disable=too-many-positional-arguments
    huggingface_pipeline_kwargs: Dict[str, Any],
    model: str,
    task: Optional[str],
    supported_tasks: List[str],
    device: Optional[ComponentDevice],
    token: Optional[Secret],
) -> Dict[str, Any]:
    """
    Resolve the HuggingFace pipeline keyword arguments based on explicit user inputs.

    :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize a
        Hugging Face pipeline.
    :param model: The name or path of a Hugging Face model on the Hugging Face Hub.
    :param task: The task for the Hugging Face pipeline.
    :param supported_tasks: The list of supported tasks to check the task of the model against. If the task of the
        model is not present in this list, a ValueError is raised.
    :param device: The device on which the model is loaded. If `None`, the default device is automatically
        selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
    :param token: The token to use as HTTP bearer authorization for remote files.
        If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
    """
    huggingface_hub_import.check()

    token = token.resolve_value() if token else None
    # check if the huggingface_pipeline_kwargs contain the essential parameters
    # otherwise, populate them with values from other init parameters
    huggingface_pipeline_kwargs.setdefault("model", model)
    huggingface_pipeline_kwargs.setdefault("token", token)

    device = ComponentDevice.resolve_device(device)
    device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

    # task identification and validation
    task = task or huggingface_pipeline_kwargs.get("task")
    if task is None and isinstance(huggingface_pipeline_kwargs["model"], str):
        task = model_info(huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]).pipeline_tag

    if task not in supported_tasks:
        raise ValueError(f"Task '{task}' is not supported. The supported tasks are: {', '.join(supported_tasks)}.")
    huggingface_pipeline_kwargs["task"] = task
    return huggingface_pipeline_kwargs


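# A hedged usage sketch for the resolver above, roughly how a local generator component could
# prepare its pipeline kwargs. The helper name, the model, the task list, and the HF_API_TOKEN
# environment variable are assumptions for illustration, not values taken from this module.
def _example_resolve_hf_pipeline_kwargs() -> Dict[str, Any]:
    return resolve_hf_pipeline_kwargs(
        huggingface_pipeline_kwargs={},
        model="google/flan-t5-small",
        task="text2text-generation",
        supported_tasks=["text-generation", "text2text-generation"],
        device=None,
        token=Secret.from_env_var("HF_API_TOKEN", strict=False),
    )

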
def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Secret]) -> None:
    """
    Check if the provided model ID corresponds to a valid model on HuggingFace Hub.

    Also check if the model is an embedding or generation model.

    :param model_id: A string representing the HuggingFace model ID.
    :param model_type: the model type, HFModelType.EMBEDDING or HFModelType.GENERATION
    :param token: The optional authentication token.
    :raises ValueError: If the model is not found or does not match the expected model type.
    """
    huggingface_hub_import.check()

    api = HfApi()
    try:
        model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
    except RepositoryNotFoundError as e:
        raise ValueError(
            f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
        ) from e

    if model_type == HFModelType.EMBEDDING:
        allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
        error_msg = f"Model {model_id} is not an embedding model. Please provide an embedding model."
    elif model_type == HFModelType.GENERATION:
        allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
        error_msg = f"Model {model_id} is not a text generation model. Please provide a text generation model."
    else:
        allowed_model = False
        error_msg = f"Unknown model type for {model_id}"

    if not allowed_model:
        raise ValueError(error_msg)


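# A brief sketch of validating a model ID before loading it. The helper name and the model ID
# are illustrative assumptions; note that the call performs a network request to the Hugging
# Face Hub and raises ValueError if the model is missing or of the wrong type.
def _example_check_valid_model() -> None:
    check_valid_model(
        model_id="sentence-transformers/all-MiniLM-L6-v2",
        model_type=HFModelType.EMBEDDING,
        token=None,
    )

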
def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
    """
    Convert a message to the format expected by Hugging Face.
    """
    text_contents = message.texts
    tool_calls = message.tool_calls
    tool_call_results = message.tool_call_results

    if not text_contents and not tool_calls and not tool_call_results:
        raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
    if len(text_contents) + len(tool_call_results) > 1:
        raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

    # HF always expects a content field, even if it is empty
    hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""}

    if tool_call_results:
        result = tool_call_results[0]
        hf_msg["content"] = result.result
        if tc_id := result.origin.id:
            hf_msg["tool_call_id"] = tc_id
        # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field
        return hf_msg

    if text_contents:
        hf_msg["content"] = text_contents[0]
    if tool_calls:
        hf_tool_calls = []
        for tc in tool_calls:
            hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}}
            if tc.id is not None:
                hf_tool_call["id"] = tc.id
            hf_tool_calls.append(hf_tool_call)
        hf_msg["tool_calls"] = hf_tool_calls

    return hf_msg


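# A minimal conversion sketch. The helper name and the question text are illustrative
# assumptions: a plain user ChatMessage becomes a role/content dict in the shape Hugging Face
# chat templates expect.
def _example_convert_message_to_hf_format() -> None:
    message = ChatMessage.from_user("What is the capital of France?")
    assert convert_message_to_hf_format(message) == {
        "role": "user",
        "content": "What is the capital of France?",
    }

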
with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
    from transformers import StoppingCriteria, TextStreamer
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

    torch_import.check()
    transformers_import.check()

    class StopWordsCriteria(StoppingCriteria):
        """
        Stops text generation in HuggingFace generators if any one of the stop words is generated.

        Note: When a stop word is encountered, the generation of new text is stopped.
        However, if the stop word is in the prompt itself, it can stop generating new text
        prematurely after the first token. This is particularly important for LLMs designed
        for dialogue generation. For these models, such as mosaicml/mpt-7b-chat,
        the output includes both the new text and the original prompt. Therefore, it's important
        to make sure your prompt has no stop words.
        """

        def __init__(
            self,
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
            stop_words: List[str],
            device: Union[str, torch.device] = "cpu",
        ):
            super().__init__()
            # check if tokenizer is a valid tokenizer
            if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
                raise ValueError(
                    f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
                    f"Please provide a valid tokenizer from the HuggingFace Transformers library."
                )
            if not tokenizer.pad_token:
                if tokenizer.eos_token:
                    tokenizer.pad_token = tokenizer.eos_token
                else:
                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
            self.stop_ids = encoded_stop_words.input_ids.to(device)

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            """Check if any of the stop words are generated in the current text generation step."""
            for stop_id in self.stop_ids:
                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
                if found_stop_word:
                    return True
            return False

        @staticmethod
        def is_stop_word_found(generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
            """
            Performs phrase matching.

            Checks if a sequence of stop tokens appears in a continuous or sequential order within the generated text.
            """
            generated_text_ids = generated_text_ids[-1]
            len_generated_text_ids = generated_text_ids.size(0)
            len_stop_id = stop_id.size(0)
            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
            return result

    class HFTokenStreamingHandler(TextStreamer):
        """
        Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.

        Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
        of generated text via Haystack Callable[StreamingChunk, None] callbacks.

        Do not use this class directly.
        """

        def __init__(
            self,
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
            stream_handler: Callable[[StreamingChunk], None],
            stop_words: Optional[List[str]] = None,
            component_info: ComponentInfo = ComponentInfo(),
        ):
            super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
            self.token_handler = stream_handler
            self.stop_words = stop_words or []
            self.component_info = component_info

        def on_finalized_text(self, word: str, stream_end: bool = False):
            """Callback function for handling the generated text."""
            word_to_send = word + "\n" if stream_end else word
            if word_to_send.strip() not in self.stop_words:
                self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info))
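

# A hedged end-to-end sketch of how the two helpers above are typically wired into a
# `generate` call, only to illustrate the plumbing that HuggingFaceLocalGenerator performs
# internally (the handler's docstring says not to use it directly). The helper name, the
# "gpt2" model, the stop word, and the prompt are assumptions for illustration.
def _example_streaming_generation_with_stop_words() -> None:
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Print each streamed chunk as soon as the streamer finalizes it
    streamer = HFTokenStreamingHandler(
        tokenizer=tokenizer,
        stream_handler=lambda chunk: print(chunk.content, end="", flush=True),
        stop_words=["Observation:"],
    )
    stopping_criteria = StoppingCriteriaList(
        [StopWordsCriteria(tokenizer=tokenizer, stop_words=["Observation:"], device="cpu")]
    )

    inputs = tokenizer("Briefly explain what a stopping criterion is.", return_tensors="pt")
    model.generate(
        **inputs,
        max_new_tokens=50,
        stopping_criteria=stopping_criteria,
        streamer=streamer,
    )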