deepset-ai / haystack / build 14064199728

25 Mar 2025 03:52PM UTC · coverage: 90.154% (+0.08%) from 90.07%

Pull Request #9055: Added retries parameters to pipeline.draw()
Merge eaafb5e56 into e64db6197 (github / web-flow)

9898 of 10979 relevant lines covered (90.15%) · 0.9 hits per line

Source file: haystack/components/generators/chat/hugging_face_api.py (96.36% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionStreamOutput,
        InferenceClient,
    )


@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api"  # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
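
    #### With tools

    A minimal, illustrative sketch of tool calling; the `get_weather` function and its
    JSON schema are hypothetical placeholders, and the selected model must support
    tool/function calling.

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    # hypothetical tool used only for illustration
    def get_weather(city: str) -> str:
        return f"The weather in {city} is sunny."

    weather_tool = Tool(
        name="get_weather",
        description="Get the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        function=get_weather,
    )

    generator = HuggingFaceAPIChatGenerator(
        api_type="serverless_inference_api",
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
        tools=[weather_tool],
    )

    result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])
    print(result["replies"][0].tool_calls)
    ```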
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        tools: Optional[List[Tool]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words:
            An optional list of strings representing the stop words.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of tools for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
            unexpected behavior.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialized_tools,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
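
    # Illustrative serialization round trip (a sketch; the constructor values are
    # placeholders taken from the usage examples above):
    #
    #     generator = HuggingFaceAPIChatGenerator(
    #         api_type="serverless_inference_api",
    #         api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    #     )
    #     data = generator.to_dict()  # plain dict, suitable for YAML/JSON pipeline files
    #     restored = HuggingFaceAPIChatGenerator.from_dict(data)
    #     assert restored.api_params == generator.api_params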

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )  # type: ignore

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
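
    # Illustrative per-run streaming (a sketch; `generator` stands for a configured
    # HuggingFaceAPIChatGenerator, and tools must not be set, since tools and
    # streaming are mutually exclusive):
    #
    #     result = generator.run(
    #         [ChatMessage.from_user("What's Natural Language Processing?")],
    #         streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
    #     )
    #     print(result["replies"][0].text)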

    @component.output_types(replies=List[ChatMessage])
    async def run_async(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Asynchronously invoke the text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)  # type: ignore

        if streaming_callback:
            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
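
    # Illustrative async usage (a sketch; assumes no event loop is already running,
    # e.g. a plain script, and that `generator` is a configured instance):
    #
    #     import asyncio
    #
    #     async def main():
    #         result = await generator.run_async([ChatMessage.from_user("Hello!")])
    #         print(result["replies"][0].text)
    #
    #     asyncio.run(main())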

    def _run_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        streaming_callback: Callable[[StreamingChunk], None],
    ):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        for chunk in api_output:
            # n is unused, so the API always returns only one choice
            # the argument is probably allowed for compatibility with OpenAI
            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            streaming_callback(stream_chunk)

        meta.update(
            {
                "model": self._client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)

        return {"replies": [message]}

    def _run_non_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        # n is unused, so the API always returns only one choice
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}

    async def _run_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        streaming_callback: Callable[[StreamingChunk], None],
    ):
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        async for chunk in api_output:
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            await streaming_callback(stream_chunk)  # type: ignore

        meta.update(
            {
                "model": self._async_client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
        return {"replies": [message]}

    async def _run_non_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._async_client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}