deepset-ai / haystack · build 12809254457

16 Jan 2025 12:46PM UTC · coverage: 91.297% (+0.005%) from 91.292%

Pull Request #8728: feat: Add completion start time timestamp to relevant chat generators
Merge a6afceb1d into 62ac27c94

8843 of 9686 relevant lines covered (91.3%) · 0.91 hits per line

Source file: haystack/components/generators/chat/hugging_face_api.py (96.46% covered)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        ChatCompletionInputTool,
        ChatCompletionOutput,
        ChatCompletionStreamOutput,
        InferenceClient,
    )


logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api" # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
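
    #### With streaming (an illustrative sketch)

    A minimal streaming setup, assuming the serverless API. `print_chunk` below is a
    hypothetical callback; any callable accepting a `StreamingChunk` works:

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage, StreamingChunk

    def print_chunk(chunk: StreamingChunk):
        # each chunk carries the newly generated text plus per-chunk metadata
        print(chunk.content or "", end="", flush=True)

    generator = HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            streaming_callback=print_chunk)

    result = generator.run([ChatMessage.from_user("What's Natural Language Processing?")])
    # timestamp of the first streamed chunk, recorded in the reply metadata
    print(result["replies"][0].meta["completion_start_time"])
    ```

    #### With tools (an illustrative sketch)

    A sketch of tool calling; `get_weather` and its parameters schema are hypothetical,
    and the model shown is a placeholder for one that supports tool calling:

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    def get_weather(city: str) -> str:
        return f"Sunny in {city}"

    tool = Tool(name="get_weather",
                description="Get the current weather for a city",
                parameters={"type": "object",
                            "properties": {"city": {"type": "string"}},
                            "required": ["city"]},
                function=get_weather)

    generator = HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            tools=[tool])

    result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])
    print(result["replies"][0].tool_calls)
    ```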
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        tools: Optional[List[Tool]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
            Some examples: `max_tokens`, `temperature`, `top_p`.
            For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words:
            An optional list of strings representing the stop words.
        :param streaming_callback:
            An optional callable for handling streaming responses.
        :param tools:
            A list of tools for which the model can prepare calls.
            The chosen model should support tool/function calling, according to the model card.
            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
            unexpected behavior.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop"] = generation_kwargs.get("stop", [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            tools=serialized_tools,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
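
    # Serialization round-trip (an illustrative sketch, not part of the original file):
    #
    #   generator = HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
    #                                           api_params={"model": "HuggingFaceH4/zephyr-7b-beta"})
    #   data = generator.to_dict()    # with the default env-var token, only a reference to the
    #                                 # env vars is stored, never the raw token value
    #   restored = HuggingFaceAPIChatGenerator.from_dict(data)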

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
    ):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage objects representing the input messages.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

        tools = tools or self.tools
        if tools and self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

        if self.streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs)

        hf_tools = None
        if tools:
            hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
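
    # Per-call overrides (an illustrative sketch): kwargs passed to run() are merged over the
    # defaults set at init, so e.g.
    #
    #   generator.run(messages, generation_kwargs={"temperature": 0.1})
    #
    # keeps the configured `max_tokens` but lowers the temperature for this call only.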

    def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        chunks_meta = []
        # timestamp of the first streamed chunk, surfaced as `completion_start_time` in the reply metadata
        first_chunk_time = None

        for chunk in api_output:
            # n is unused, so the API always returns only one choice
            # the argument is probably allowed for compatibility with OpenAI
            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
            choice = chunk.choices[0]

            text = choice.delta.content
            if text:
                generated_text += text

            finish_reason = choice.finish_reason

            meta = {
                "finish_reason": finish_reason if finish_reason else None,
                "received_at": datetime.now().isoformat(),
            }
            chunks_meta.append(meta)

            if first_chunk_time is None:
                first_chunk_time = meta["received_at"]

            stream_chunk = StreamingChunk(text, meta)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)

        meta.update(
            {
                "model": self._client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
                "completion_start_time": first_chunk_time,
            }
        )

        message = ChatMessage.from_assistant(text=generated_text, meta=meta)

        return {"replies": [message]}

    def _run_non_streaming(
        self,
        messages: List[Dict[str, str]],
        generation_kwargs: Dict[str, Any],
        tools: Optional[List["ChatCompletionInputTool"]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
            messages=messages, tools=tools, **generation_kwargs
        )

        if len(api_chat_output.choices) == 0:
            return {"replies": []}

        # n is unused, so the API always returns only one choice
        # the argument is probably allowed for compatibility with OpenAI
        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
        choice = api_chat_output.choices[0]

        text = choice.message.content
        tool_calls = []

        if hfapi_tool_calls := choice.message.tool_calls:
            for hfapi_tc in hfapi_tool_calls:
                tool_call = ToolCall(
                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
                )
                tool_calls.append(tool_call)

        meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}

        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        if api_chat_output.usage:
            usage = {
                "prompt_tokens": api_chat_output.usage.prompt_tokens,
                "completion_tokens": api_chat_output.usage.completion_tokens,
            }
        meta["usage"] = usage

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}