deepset-ai / haystack · build 14331194913

08 Apr 2025 10:25AM UTC · coverage: 90.128% (+0.04%) from 90.085%

Pull Request #9182: enhancement: Add attributes to PipelineRuntimeError
Merge 2f5d1c5e9 into 2665d048b (github / web-flow)

10664 of 11832 relevant lines covered (90.13%) · 0.9 hits per line
Source File

haystack/components/generators/chat/hugging_face_api.py (95.27% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools import (
    Tool,
    Toolset,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionStreamOutput,
        InferenceClient,
    )


@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api" # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
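
    #### With tools

    A minimal sketch of tool usage (the `get_weather` function, its schema, and the model
    choice are illustrative; as noted below, tool support in the Hugging Face API and TGI
    is not yet fully refined, so behavior may vary):

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools import Tool

    def get_weather(city: str) -> str:
        return f"Sunny in {city}"

    weather_tool = Tool(
        name="get_weather",
        description="Get the current weather for a city.",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    generator = HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            tools=[weather_tool])

    result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])
    print(result["replies"][0].tool_calls)
    ```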
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words:
            An optional list of strings representing the stop words.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
            unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
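
        A round-trip sketch (illustrative; the exact dictionary depends on the configured parameters):

        ```python
        generator = HuggingFaceAPIChatGenerator(
            api_type="serverless_inference_api",
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
        )
        data = generator.to_dict()
        restored = HuggingFaceAPIChatGenerator.from_dict(data)
        ```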
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialize_tools_or_toolset(self.tools),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
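        # streaming callbacks are serialized as dotted import paths; restore the callable if one was set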
1✔
219
        if serialized_callback_handler:
1✔
220
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
×
221
        return default_from_dict(cls, data)
1✔
222

223
    @component.output_types(replies=List[ChatMessage])
1✔
224
    def run(
1✔
225
        self,
226
        messages: List[ChatMessage],
227
        generation_kwargs: Optional[Dict[str, Any]] = None,
228
        tools: Optional[Union[List[Tool], Toolset]] = None,
229
        streaming_callback: Optional[StreamingCallbackT] = None,
230
    ):
231
        """
232
        Invoke the text generation inference based on the provided messages and generation parameters.
233

234
        :param messages:
235
            A list of ChatMessage objects representing the input messages.
236
        :param generation_kwargs:
237
            Additional keyword arguments for text generation.
238
        :param tools:
239
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
240
            the `tools` parameter set during component initialization. This parameter can accept either a
241
            list of `Tool` objects or a `Toolset` instance.
242
        :param streaming_callback:
243
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
244
            parameter set during component initialization.
245
        :returns: A dictionary with the following keys:
246
            - `replies`: A list containing the generated responses as ChatMessage objects.
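
        Per-call override sketch (illustrative; assumes a configured `generator` and `messages`
        as in the class-level examples):

        ```python
        result = generator.run(messages, streaming_callback=print)
        ```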
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            if isinstance(tools, Toolset):
                tools = list(tools)
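            # convert Haystack Tool objects into the huggingface_hub chat-completion tool schema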
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

    @component.output_types(replies=List[ChatMessage])
    async def run_async(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
            parameter set during component initialization. This parameter can accept either a list of `Tool` objects
            or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # validate and select the streaming callback
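        # requires_async=True: the selected callback must be a coroutine function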
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            if isinstance(tools, Toolset):
                tools = list(tools)
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)

    def _run_streaming(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
    ):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        for chunk in api_output:
            # n is unused, so the API always returns only one choice
            # the argument is probably allowed for compatibility with OpenAI
            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            streaming_callback(stream_chunk)

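        # `meta` and `finish_reason` retain the values from the last streamed chunk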
        meta.update(
            {
                "model": self._client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)

        return {"replies": [message]}

    def _run_non_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        # n is unused, so the API always returns only one choice
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}

    async def _run_streaming_async(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
    ):
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        async for chunk in api_output:
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            await streaming_callback(stream_chunk)  # type: ignore

        meta.update(
            {
                "model": self._async_client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
        return {"replies": [message]}

    async def _run_non_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._async_client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}