deepset-ai / haystack · build 16068286960 · 04 Jul 2025 07:26AM UTC
Pull Request #9589: chore: reenable some HF API tests + improve docstrings (merge 8630827d6 into 050c98794)
Repository coverage: 90.432% (11682 of 12918 relevant lines covered, 0.9 hits per line); remained the same.

Source file: haystack/components/generators/chat/hugging_face_api.py — 94.79% covered
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
from datetime import datetime
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import (
    AsyncStreamingCallbackT,
    ChatMessage,
    ComponentInfo,
    StreamingCallbackT,
    StreamingChunk,
    SyncStreamingCallbackT,
    ToolCall,
    select_streaming_callback,
)
from haystack.dataclasses.streaming_chunk import FinishReason
from haystack.lazy_imports import LazyImport
from haystack.tools import (
    Tool,
    Toolset,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputStreamOptions,
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionOutputToolCall,
        ChatCompletionStreamOutput,
        ChatCompletionStreamOutputChoice,
        InferenceClient,
    )


def _convert_hfapi_tool_calls(hfapi_tool_calls: Optional[List["ChatCompletionOutputToolCall"]]) -> List[ToolCall]:
    """
    Convert HuggingFace API tool calls to a list of Haystack ToolCall objects.

    :param hfapi_tool_calls: The HuggingFace API tool calls to convert.
    :returns: A list of ToolCall objects.
    """
    if not hfapi_tool_calls:
        return []

    tool_calls = []

    for hfapi_tc in hfapi_tool_calls:
        hf_arguments = hfapi_tc.function.arguments

        arguments = None
        if isinstance(hf_arguments, dict):
            arguments = hf_arguments
        elif isinstance(hf_arguments, str):
            try:
                arguments = json.loads(hf_arguments)
            except json.JSONDecodeError:
                logger.warning(
                    "HuggingFace API returned a malformed JSON string for tool call arguments. This tool call "
                    "will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                    _id=hfapi_tc.id,
                    _name=hfapi_tc.function.name,
                    _arguments=hf_arguments,
                )
        else:
            logger.warning(
                "HuggingFace API returned tool call arguments of type {_type}. Valid types are dict and str. This tool "
                "call will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                _type=type(hf_arguments).__name__,
                _id=hfapi_tc.id,
                _name=hfapi_tc.function.name,
                _arguments=hf_arguments,
            )

        if arguments:
            tool_calls.append(ToolCall(tool_name=hfapi_tc.function.name, arguments=arguments, id=hfapi_tc.id))

    return tool_calls
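
# Illustrative sketch (not part of the original module): `_convert_hfapi_tool_calls` relies only
# on the `.id`, `.function.name`, and `.function.arguments` attributes of each tool call, so a
# minimal stand-in object is enough to show the behavior. JSON-string arguments are parsed into
# a dict; malformed JSON is logged and the call is skipped.
#
#     from types import SimpleNamespace
#
#     ok = SimpleNamespace(id="call_1", function=SimpleNamespace(name="weather", arguments='{"city": "Berlin"}'))
#     _convert_hfapi_tool_calls([ok])
#     # -> [ToolCall(tool_name="weather", arguments={"city": "Berlin"}, id="call_1")]
#
#     bad = SimpleNamespace(id="call_2", function=SimpleNamespace(name="weather", arguments="{broken"))
#     _convert_hfapi_tool_calls([bad])
#     # -> [] (a warning is logged and the malformed call is skipped)
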

def _convert_tools_to_hfapi_tools(
    tools: Optional[Union[List[Tool], Toolset]],
) -> Optional[List["ChatCompletionInputTool"]]:
    if not tools:
        return None

    # huggingface_hub<0.31.0 uses "arguments", huggingface_hub>=0.31.0 uses "parameters"
    parameters_name = "arguments" if hasattr(ChatCompletionInputFunctionDefinition, "arguments") else "parameters"

    hf_tools = []
    for tool in tools:
        hf_tools_args = {"name": tool.name, "description": tool.description, parameters_name: tool.parameters}

        hf_tools.append(
            ChatCompletionInputTool(function=ChatCompletionInputFunctionDefinition(**hf_tools_args), type="function")
        )

    return hf_tools


def _map_hf_finish_reason_to_haystack(choice: "ChatCompletionStreamOutputChoice") -> Optional[FinishReason]:
    """
    Map HuggingFace finish reasons to Haystack FinishReason literals.

    Uses the full choice object to detect tool calls and provide an accurate mapping.

    HuggingFace finish reasons (see https://huggingface.github.io/text-generation-inference/ under FinishReason):
    - "length": the number of generated tokens reached `max_new_tokens`
    - "eos_token": the model generated its end-of-sequence token
    - "stop_sequence": the model generated a text included in `stop_sequences`

    Additionally, tool calls are detected from `delta.tool_calls` or `delta.tool_call_id`.

    :param choice: The HuggingFace ChatCompletionStreamOutputChoice object.
    :returns: The corresponding Haystack FinishReason or None.
    """
    if choice.finish_reason is None:
        return None

    # Check if this choice contains tool call information
    has_tool_calls = choice.delta.tool_calls is not None or choice.delta.tool_call_id is not None

    # If we detect tool calls, override the finish reason
    if has_tool_calls:
        return "tool_calls"

    # Map HuggingFace finish reasons to Haystack standard ones
    mapping: Dict[str, FinishReason] = {
        "length": "length",  # Direct match
        "eos_token": "stop",  # EOS token means natural stop
        "stop_sequence": "stop",  # Stop sequence means natural stop
    }

    return mapping.get(choice.finish_reason, "stop")  # Default to "stop" for unknown reasons


152
def _convert_chat_completion_stream_output_to_streaming_chunk(
1✔
153
    chunk: "ChatCompletionStreamOutput",
154
    previous_chunks: List[StreamingChunk],
155
    component_info: Optional[ComponentInfo] = None,
156
) -> StreamingChunk:
157
    """
158
    Converts the Hugging Face API ChatCompletionStreamOutput to a StreamingChunk.
159
    """
160
    # Choices is empty if include_usage is set to True where the usage information is returned.
161
    if len(chunk.choices) == 0:
1✔
162
        usage = None
1✔
163
        if chunk.usage:
1✔
164
            usage = {"prompt_tokens": chunk.usage.prompt_tokens, "completion_tokens": chunk.usage.completion_tokens}
1✔
165
        return StreamingChunk(
1✔
166
            content="",
167
            meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "usage": usage},
168
            component_info=component_info,
169
        )
170

171
    # n is unused, so the API always returns only one choice
172
    # the argument is probably allowed for compatibility with OpenAI
173
    # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
174
    choice = chunk.choices[0]
1✔
175
    mapped_finish_reason = _map_hf_finish_reason_to_haystack(choice) if choice.finish_reason else None
1✔
176
    stream_chunk = StreamingChunk(
1✔
177
        content=choice.delta.content or "",
178
        meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason},
179
        component_info=component_info,
180
        # Index must always be 0 since we don't allow tool calls in streaming mode.
181
        index=0 if choice.finish_reason is None else None,
182
        # start is True at the very beginning since first chunk contains role information + first part of the answer.
183
        start=len(previous_chunks) == 0,
184
        finish_reason=mapped_finish_reason,
185
    )
186
    return stream_chunk
1✔
187

188

@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Serverless Inference API (Inference Providers)](https://huggingface.co/docs/inference-providers)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the serverless inference API (Inference Providers) - free tier available

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api"  # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "microsoft/Phi-3.5-mini-instruct",
                                                        "provider": "featherless-ai"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
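
    #### With tools

    A minimal sketch of tool calling, assuming a hosted model that supports function calling;
    the `weather` function and its JSON schema below are illustrative placeholders, not part
    of Haystack.

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools import Tool
    from haystack.utils import Secret

    def weather(city: str) -> str:
        return f"Sunny in {city}"

    tool = Tool(name="weather",
                description="Get the weather for a city",
                parameters={"type": "object",
                            "properties": {"city": {"type": "string"}},
                            "required": ["city"]},
                function=weather)

    generator = HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
                                            api_params={"model": "microsoft/Phi-3.5-mini-instruct"},
                                            token=Secret.from_token("<your-api-key>"),
                                            tools=[tool])

    result = generator.run([ChatMessage.from_user("What's the weather in Berlin?")])
    print(result["replies"][0].tool_calls)
    ```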
    """
259

260
    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See
            [Serverless Inference API - Inference Providers](https://huggingface.co/docs/inference-providers).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `provider`: Provider name. Recommended when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
            - Other parameters specific to the chosen API type, such as `timeout`, `headers`, etc.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words:
            An optional list of strings representing the stop words.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
            unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback

        resolved_api_params: Dict[str, Any] = {k: v for k, v in api_params.items() if k != "model" and k != "url"}
        self._client = InferenceClient(
            model_or_url, token=token.resolve_value() if token else None, **resolved_api_params
        )
        self._async_client = AsyncInferenceClient(
            model_or_url, token=token.resolve_value() if token else None, **resolved_api_params
        )
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialize_tools_or_toolset(self.tools),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
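
    # Illustrative round trip (not part of the original module): `to_dict` captures the init
    # parameters (serializing the token as a Secret and any tools/toolset), and `from_dict`
    # rebuilds an equivalent component:
    #
    #     data = generator.to_dict()
    #     restored = HuggingFaceAPIChatGenerator.from_dict(data)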

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
            the `tools` parameter set during component initialization. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        # check the effective streaming callback: one passed at runtime or one set at init time
        if tools and (streaming_callback is not None or self.streaming_callback is not None):
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )

        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

        if tools and isinstance(tools, Toolset):
            tools = list(tools)

        hf_tools = _convert_tools_to_hfapi_tools(tools)

        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
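
    # Illustrative sketch (not part of the original module): per-call `generation_kwargs`
    # are merged over the defaults set at init time, with the per-call values winning:
    #
    #     result = generator.run(messages, generation_kwargs={"temperature": 0.1, "max_tokens": 1024})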

    @component.output_types(replies=List[ChatMessage])
    async def run_async(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
            parameter set during component initialization. This parameter can accept either a list of `Tool` objects
            or a `Toolset` instance.
        :param streaming_callback:
            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
            parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        # check the effective streaming callback: one passed at runtime or one set at init time
        if tools and (streaming_callback is not None or self.streaming_callback is not None):
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)

        if tools and isinstance(tools, Toolset):
            tools = list(tools)

        hf_tools = _convert_tools_to_hfapi_tools(tools)

        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
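
    # Illustrative sketch (not part of the original module): a streaming callback receives each
    # StreamingChunk as it arrives; for example, printing the content incrementally:
    #
    #     def print_chunk(chunk: StreamingChunk) -> None:
    #         print(chunk.content, end="", flush=True)
    #
    #     generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
    #                                             api_params={"url": "http://localhost:8080"},
    #                                             streaming_callback=print_chunk)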

    def _run_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        streaming_callback: SyncStreamingCallbackT,
    ):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages,
            stream=True,
            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
            **generation_kwargs,
        )

        component_info = ComponentInfo.from_component(self)
        streaming_chunks: List[StreamingChunk] = []
        for chunk in api_output:
            streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
                chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info
            )
            streaming_chunks.append(streaming_chunk)
            streaming_callback(streaming_chunk)

        message = _convert_streaming_chunks_to_chat_message(chunks=streaming_chunks)
        if message.meta.get("usage") is None:
            message.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0}

        return {"replies": [message]}

519
    def _run_non_streaming(
1✔
520
        self,
521
        messages: List[Dict[str, str]],
522
        generation_kwargs: Dict[str, Any],
523
        tools: Optional[List["ChatCompletionInputTool"]] = None,
524
    ) -> Dict[str, List[ChatMessage]]:
525
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
1✔
526
            messages=messages, tools=tools, **generation_kwargs
527
        )
528

529
        if len(api_chat_output.choices) == 0:
1✔
530
            return {"replies": []}
×
531

532
        # n is unused, so the API always returns only one choice
533
        # the argument is probably allowed for compatibility with OpenAI
534
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
535
        choice = api_chat_output.choices[0]
1✔
536

537
        text = choice.message.content
1✔
538

539
        tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)
1✔
540

541
        meta: Dict[str, Any] = {
1✔
542
            "model": self._client.model,
543
            "finish_reason": choice.finish_reason,
544
            "index": choice.index,
545
        }
546

547
        usage = {"prompt_tokens": 0, "completion_tokens": 0}
1✔
548
        if api_chat_output.usage:
1✔
549
            usage = {
1✔
550
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
551
                "completion_tokens": api_chat_output.usage.completion_tokens,
552
            }
553
        meta["usage"] = usage
1✔
554

555
        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
1✔
556
        return {"replies": [message]}
1✔
557

    async def _run_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        streaming_callback: AsyncStreamingCallbackT,
    ):
        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
            messages,
            stream=True,
            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
            **generation_kwargs,
        )

        component_info = ComponentInfo.from_component(self)
        streaming_chunks: List[StreamingChunk] = []
        async for chunk in api_output:
            stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
                chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info
            )
            streaming_chunks.append(stream_chunk)
            await streaming_callback(stream_chunk)  # type: ignore

        message = _convert_streaming_chunks_to_chat_message(chunks=streaming_chunks)
        if message.meta.get("usage") is None:
            message.meta["usage"] = {"prompt_tokens": 0, "completion_tokens": 0}

        return {"replies": [message]}

    async def _run_non_streaming_async(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        choice = api_chat_output.choices[0]

        text = choice.message.content

        tool_calls = _convert_hfapi_tool_calls(choice.message.tool_calls)

        meta: Dict[str, Any] = {
            "model": self._async_client.model,
            "finish_reason": choice.finish_reason,
            "index": choice.index,
        }

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}