
deepset-ai / haystack · build 14309102962
07 Apr 2025 12:24PM UTC · coverage: 90.086% (+0.001% from 90.085%)
Pull Request #9177: feat: Add Toolset support in ChatGenerator(s)
Merge 29241278d into 63781afd8
10641 of 11812 relevant lines covered (90.09%) · 0.9 hits per line

Source file: haystack/components/generators/chat/hugging_face_api.py · 95.32% covered
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.toolset import Toolset
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.misc import serialize_tools_or_toolset
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionStreamOutput,
        InferenceClient,
    )


@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api"  # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
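
    #### With tools

    `tools` accepts either a list of `Tool` objects or a `Toolset`. The snippet below is a
    minimal sketch: the `get_weather` function and its parameters schema are illustrative
    placeholders, not part of Haystack.

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    # hypothetical tool implementation, for illustration only
    def get_weather(city: str) -> str:
        return f"Sunny in {city}"

    weather_tool = Tool(
        name="get_weather",
        description="Get the current weather for a city.",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    generator = HuggingFaceAPIChatGenerator(
        api_type="serverless_inference_api",
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
        tools=[weather_tool],  # a Toolset also works: tools=Toolset([weather_tool])
    )

    result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])
    print(result["replies"][0].tool_calls)
    ```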
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words:
            An optional list of strings representing the stop words.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined, and you may experience
            unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

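        # tools and streaming are mutually exclusive; duplicate tool names are rejected up front.
        # list(tools or []) handles both accepted input types, since a Toolset is iterable.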
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialize_tools_or_toolset(self.tools),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
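        # restore nested objects in place (the Secret token and the serialized tools)
        # before the streaming callback is re-imported from its dotted path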
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
            the `tools` parameter set during component initialization. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

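        # translate Haystack Tool objects (or a Toolset, which is materialized into a list)
        # into the ChatCompletionInputTool format that huggingface_hub expects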
        hf_tools = None
        if tools:
            if isinstance(tools, Toolset):
                tools = list(tools)
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

    @component.output_types(replies=List[ChatMessage])
    async def run_async(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
            parameter set during component initialization. This parameter can accept either a list of `Tool` objects
            or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

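        # as in run(): translate the tools into the huggingface_hub tool format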
        hf_tools = None
        if tools:
            if isinstance(tools, Toolset):
                tools = list(tools)
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)

    def _run_streaming(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
    ):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        for chunk in api_output:
            # n is unused, so the API always returns only one choice
            # the argument is probably allowed for compatibility with OpenAI
            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            streaming_callback(stream_chunk)

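        # once the stream is exhausted, finalize the metadata of the assembled reply;
        # token usage is not reported by the streaming API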
        meta.update(
            {
                "model": self._client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)

        return {"replies": [message]}

    def _run_non_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        # n is unused, so the API always returns only one choice
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

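        # usage is only populated when the server reports it; otherwise fall back to zeros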
        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}

    async def _run_streaming_async(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
    ):
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        async for chunk in api_output:
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            await streaming_callback(stream_chunk)  # type: ignore

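        # finalize the metadata of the assembled reply, as in the sync variant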
        meta.update(
            {
                "model": self._async_client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
        return {"replies": [message]}

    async def _run_non_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._async_client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}