deepset-ai / haystack / build 14200685885

01 Apr 2025 04:13PM UTC. Coverage: 90.35% (+0.001%) from 90.349%. Triggered by a push, via GitHub (committer: web-flow).

Consolidate the use of `select_streaming_callback` utility in OpenAI and Azure ChatGenerators (#9156)

* Always use select_streaming_callback on ChatGenerators; update types; remove unneeded `type: ignore`
* Add release note
* Remove other unneeded `type: ignore`

10542 of 11668 relevant lines covered (90.35%)

0.9 hits per line
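
The change consolidates how the OpenAI and Azure chat generators resolve their effective streaming callback, using the same `select_streaming_callback` utility that the file below imports from `haystack.dataclasses`. A minimal sketch of the selection logic follows, as a hypothetical simplification rather than the actual implementation:

```python
# Hypothetical simplification (not the real implementation) of the
# callback-selection logic consolidated by this change: prefer the runtime
# callback over the init-time one, and verify that the chosen callback
# matches the execution mode (sync vs. async).
import inspect


def select_streaming_callback_sketch(init_callback, runtime_callback, *, requires_async: bool):
    callback = runtime_callback or init_callback
    if callback is not None:
        is_async = inspect.iscoroutinefunction(callback)
        if requires_async and not is_async:
            raise ValueError("An async run requires an async streaming callback.")
        if not requires_async and is_async:
            raise ValueError("A sync run requires a sync streaming callback.")
    return callback
```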

Source file: haystack/components/generators/chat/hugging_face_api.py (96.39% covered)

(Uncovered lines in this report: the unknown-api_type error branch, the streaming-callback deserialization line in from_dict, the tools-with-streaming raise in run_async, and the empty-choices early returns in the non-streaming paths.)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionStreamOutput,
        InferenceClient,
    )


@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api"  # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[List[Tool]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words:
            An optional list of strings representing the stop words.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of tools for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
            unexpected behavior.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools
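        # Note: a sync and an async client are created above so that run() and
        # run_async() can share one configuration without re-resolving the token.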

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialized_tools,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
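
    # Round-trip sketch (illustrative): the streaming callback is serialized as a
    # dotted import path via serialize_callable, the token via Secret.to_dict, and
    # tools via Tool.to_dict, so an equivalent component can be reconstructed:
    #
    #   restored = HuggingFaceAPIChatGenerator.from_dict(generator.to_dict())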

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

        hf_tools = None
        if tools:
            hf_tools = [
                ChatCompletionInputTool(
                    function=ChatCompletionInputFunctionDefinition(
                        name=tool.name, description=tool.description, arguments=tool.parameters
                    ),
                    type="function",
                )
                for tool in tools
            ]
        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
274
    @component.output_types(replies=List[ChatMessage])
1✔
275
    async def run_async(
1✔
276
        self,
277
        messages: List[ChatMessage],
278
        generation_kwargs: Optional[Dict[str, Any]] = None,
279
        tools: Optional[List[Tool]] = None,
280
        streaming_callback: Optional[StreamingCallbackT] = None,
281
    ):
282
        """
283
        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
284

285
        This is the asynchronous version of the `run` method. It has the same parameters
286
        and return values but can be used with `await` in an async code.
287

288
        :param messages:
289
            A list of ChatMessage objects representing the input messages.
290
        :param generation_kwargs:
291
            Additional keyword arguments for text generation.
292
        :param tools:
293
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
294
            during component initialization.
295
        :param streaming_callback:
296
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
297
            parameter set during component initialization.
298
        :returns: A dictionary with the following keys:
299
            - `replies`: A list containing the generated responses as ChatMessage objects.
300
        """
301

302
        # update generation kwargs by merging with the default ones
303
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
1✔
304

305
        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
1✔
306

307
        tools = tools or self.tools
1✔
308
        if tools and self.streaming_callback:
1✔
309
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
×
310
        _check_duplicate_tool_names(tools)
1✔
311

312
        # validate and select the streaming callback
313
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
1✔
314

315
        if streaming_callback:
1✔
316
            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
1✔
317

318
        hf_tools = None
1✔
319
        if tools:
1✔
320
            hf_tools = [
1✔
321
                ChatCompletionInputTool(
322
                    function=ChatCompletionInputFunctionDefinition(
323
                        name=tool.name, description=tool.description, arguments=tool.parameters
324
                    ),
325
                    type="function",
326
                )
327
                for tool in tools
328
            ]
329
        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
1✔
330

    def _run_streaming(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
    ):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        for chunk in api_output:
            # the `n` parameter is not supported, so the API always returns only one choice
            # the argument is probably allowed for compatibility with OpenAI
            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            streaming_callback(stream_chunk)
362
        meta.update(
1✔
363
            {
364
                "model": self._client.model,
365
                "finish_reason": finish_reason,
366
                "index": 0,
367
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
368
                "completion_start_time": first_chunk_time,
369
            }
370
        )
371

372
        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
1✔
373

374
        return {"replies": [message]}
1✔
375

    def _run_non_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        # the `n` parameter is not supported, so the API always returns only one choice
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}

    async def _run_streaming_async(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
    ):
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        first_chunk_time = None

        async for chunk in api_output:
            choice = chunk.choices[0]

            text = choice.delta.content or ""
            generated_text += text

            finish_reason = choice.finish_reason

            meta: Dict[str, Any] = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(text, meta)
            await streaming_callback(stream_chunk)  # type: ignore

        meta.update(
            {
                "model": self._async_client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
        return {"replies": [message]}

    async def _run_non_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta: Dict[str, Any] = {
            "model": self._async_client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}
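
For reference, a minimal streaming usage sketch for the component above. The model name is a placeholder; because `run` calls `select_streaming_callback` with `requires_async=False`, the callback is a plain synchronous function, and each `StreamingChunk` carries the streamed fragment as its first field (`content`), consistent with the `StreamingChunk(text, meta)` construction in `_run_streaming`:

```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


def on_chunk(chunk: StreamingChunk) -> None:
    # print each streamed fragment as it arrives
    print(chunk.content, end="", flush=True)


generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # placeholder model
    streaming_callback=on_chunk,
)
result = generator.run([ChatMessage.from_user("What's Natural Language Processing?")])
print(result["replies"][0].text)
```

The same pattern should work with `run_async` and an `async def` callback, since `run_async` selects the callback with `requires_async=True`.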