• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 12370483819

17 Dec 2024 09:51AM UTC coverage: 90.565% (+0.09%) from 90.48%
12370483819

Pull #8640

github

web-flow
Merge b88daaea0 into a5b57f4b1
Pull Request #8640: feat!: new `ChatMessage`

8207 of 9062 relevant lines covered (90.56%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.67
haystack/components/generators/chat/hugging_face_api.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
1✔
6

7
from haystack import component, default_from_dict, default_to_dict, logging
1✔
8
from haystack.dataclasses import ChatMessage, StreamingChunk
1✔
9
from haystack.lazy_imports import LazyImport
1✔
10
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1✔
11
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
1✔
12
from haystack.utils.url_validation import is_valid_http_url
1✔
13

14
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
1✔
15
    from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
1✔
16

17

18
logger = logging.getLogger(__name__)
1✔
19

20

21
def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
    """
    Convert a message to the format expected by Hugging Face APIs.

    :returns: A dictionary with the following keys:
        - `role`
        - `content`
    """
    # A message with no text is sent as an empty string rather than None.
    content = message.text if message.text else ""
    return {"role": message.role.value, "content": content}
1✔
30

31

32
@component
class HuggingFaceAPIChatGenerator:
    """
    Completes chats using Hugging Face APIs.

    HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)
    format for input and output. Use it to generate text with Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret
    from haystack.utils.hf import HFGenerationAPIType

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
    api_type = "serverless_inference_api" # this is equivalent to the above

    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
                                            api_params={"url": "<your-inference-endpoint-url>"},
                                            token=Secret.from_token("<your-api-key>"))

    result = generator.run(messages)
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
                ChatMessage.from_user("What's Natural Language Processing?")]

    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
                                            api_params={"url": "http://localhost:8080"})

    result = generator.run(messages)
    print(result)
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceAPIChatGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation.
                Some examples: `max_tokens`, `temperature`, `top_p`.
                For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :raises ValueError: If `api_type` is unknown, or if the required `model`/`url` entry is
            missing from `api_params`, or if the provided `url` is not a valid HTTP(S) URL.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter "
                    "in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        # .copy() above is shallow: copy the nested `stop` list too, so extending it
        # below never mutates a list object the caller passed in.
        generation_kwargs["stop"] = list(generation_kwargs.get("stop") or [])
        generation_kwargs["stop"].extend(stop_words or [])
        generation_kwargs.setdefault("max_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary produced by `to_dict`.
        :returns: The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=List[ChatMessage])
    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage objects.
        """

        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]

        if self.streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs)

        return self._run_non_streaming(formatted_messages, generation_kwargs)

    def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
        # Stream chunks from the API, invoking the user callback per chunk, then
        # assemble the full assistant message from the accumulated deltas.
        api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
            messages, stream=True, **generation_kwargs
        )

        generated_text = ""
        # Initialize so the post-loop meta update is safe even if the stream
        # yields no chunks (previously this raised UnboundLocalError).
        finish_reason = None

        for chunk in api_output:  # pylint: disable=not-an-iterable
            text = chunk.choices[0].delta.content
            if text:
                generated_text += text
            finish_reason = chunk.choices[0].finish_reason

            meta = {}
            if finish_reason:
                meta["finish_reason"] = finish_reason

            stream_chunk = StreamingChunk(text, meta)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)

        message = ChatMessage.from_assistant(generated_text)
        message.meta.update(
            {
                "model": self._client.model,
                "finish_reason": finish_reason,
                "index": 0,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
            }
        )
        return {"replies": [message]}

    def _run_non_streaming(
        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
    ) -> Dict[str, List[ChatMessage]]:
        # Single blocking API call; one ChatMessage per returned choice.
        chat_messages: List[ChatMessage] = []

        api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
        for choice in api_chat_output.choices:
            message = ChatMessage.from_assistant(choice.message.content)
            message.meta.update(
                {
                    "model": self._client.model,
                    "finish_reason": choice.finish_reason,
                    "index": choice.index,
                    "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
                }
            )
            chat_messages.append(message)

        return {"replies": chat_messages}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc