haystack/components/generators/chat/hugging_face_api.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionStreamOutput,
        InferenceClient,
    )


logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIChatGenerator:
    """
32
    Completes chats using Hugging Face APIs.
33

34
    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
35
    format for input and output. Use it to generate text with Hugging Face APIs:
36
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
37
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
38
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)
39

40
    ### Usage examples
41

42
    #### With the free serverless inference API
43

44
    ```python
45
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
46
    from haystack.dataclasses import ChatMessage
47
    from haystack.utils import Secret
48
    from haystack.utils.hf import HFGenerationAPIType
49

50
    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
51
                ChatMessage.from_user("What's Natural Language Processing?")]
52

53
    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
54
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
55
    api_type = "serverless_inference_api" # this is equivalent to the above
56

57
    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
58
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
59
                                            token=Secret.from_token("<your-api-key>"))
60

61
    result = generator.run(messages)
62
    print(result)
63
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
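
    #### With tools (illustrative sketch)

    The generator can also prepare tool calls. The example below is a sketch rather than part of
    the original documentation: the `weather` function, its JSON schema, and the chosen model are
    assumptions, and the model you pick must actually support function calling.

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    # hypothetical tool function, for illustration only
    def weather(city: str) -> str:
        return f"The weather in {city} is sunny."

    tool = Tool(
        name="weather",
        description="Get the weather for a given city",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=weather,
    )

    generator = HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            tools=[tool])

    result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])
    print(result["replies"][0].tool_calls)
    ```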
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        tools: Optional[List[Tool]] = None,
    ):
        """
110
        Initialize the HuggingFaceAPIChatGenerator instance.
111

112
        :param api_type:
113
            The type of Hugging Face API to use. Available types:
114
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
115
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
116
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
117
        :param api_params:
118
            A dictionary with the following keys:
119
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
120
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
121
            `TEXT_GENERATION_INFERENCE`.
122
        :param token:
123
            The Hugging Face token to use as HTTP bearer authorization.
124
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
125
        :param generation_kwargs:
126
            A dictionary with keyword arguments to customize text generation.
127
                Some examples: `max_tokens`, `temperature`, `top_p`.
128
                For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
129
        :param stop_words:
130
            An optional list of strings representing the stop words.
131
        :param streaming_callback:
132
            An optional callable for handling streaming responses.
133
        :param tools:
134
            A list of tools for which the model can prepare calls.
135
            The chosen model should support tool/function calling, according to the model card.
136
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
137
            unexpected behavior.
138
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        # the InferenceClient takes either a model ID (serverless) or an endpoint URL as its first argument
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialized_tools,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
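
    # A minimal to_dict/from_dict round-trip sketch (illustrative, not part of the original file;
    # the TGI configuration is chosen so construction needs no network access):
    #
    #   gen = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
    #                                     api_params={"url": "http://localhost:8080"})
    #   restored = HuggingFaceAPIChatGenerator.from_dict(gen.to_dict())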

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        # resolve the effective tools and streaming callback: run() arguments override init values.
        # the incompatibility check uses the resolved callback, so one passed at run time is caught too.
        tools = tools or self.tools
        streaming_callback = streaming_callback or self.streaming_callback
        if tools and streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            # convert Haystack Tool objects into the tool format expected by the huggingface_hub client
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
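
    # A minimal streaming usage sketch (illustrative, not part of the original file; the model
    # name is an assumption and the callback simply prints each chunk as it arrives):
    #
    #   generator = HuggingFaceAPIChatGenerator(
    #       api_type="serverless_inference_api",
    #       api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    #       streaming_callback=lambda chunk: print(chunk.content, end=""),
    #   )
    #   generator.run([ChatMessage.from_user("Summarize NLP in one sentence.")])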

    def _run_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        streaming_callback: Callable[[StreamingChunk], None],
    ):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        for chunk in api_output:
            # the `n` parameter is not supported, so the API always returns exactly one choice;
            # the argument is probably allowed for compatibility with OpenAI
            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            streaming_callback(stream_chunk)

        # `meta` and `finish_reason` carry over from the last chunk received in the loop above
        meta.update(
            {
                "model": self._client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)

        return {"replies": [message]}

    def _run_non_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        # the `n` parameter is not supported, so the API always returns exactly one choice;
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}