deepset-ai / haystack / build 13196574846

07 Feb 2025 09:04AM UTC coverage: 92.158% (+0.9%) from 91.299%

Pull Request #8817: fix: Update OpenAPIServiceConnector to new ChatMessage
Merge bfbca25b1 into 1785ea622

9025 of 9793 relevant lines covered (92.16%)
0.92 hits per line

Source File: haystack/components/generators/hugging_face_api.py (96.3% of lines covered)
Per the line-by-line report, the only uncovered lines in this file are the unknown `api_type` fallback in `__init__` and the special-token `continue` in `_stream_and_build_response`.
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import asdict
from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        InferenceClient,
        TextGenerationOutput,
        TextGenerationStreamOutput,
        TextGenerationStreamOutputToken,
    )

logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIGenerator:
    """
    Generates text using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    ### Usage examples

    #### With the free serverless inference API

    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="inference_endpoints",
                                        api_params={"url": "<your-inference-endpoint-url>"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    #### With self-hosted text generation inference

    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator

    generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                        api_params={"url": "http://localhost:8080"})

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceAPIGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation. Some examples: `max_new_tokens`,
            `temperature`, `top_k`, `top_p`.
            For details, see the [Hugging Face documentation](https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation).
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])
        generation_kwargs.setdefault("max_new_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        init_params = data["init_parameters"]
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
    def run(
        self,
        prompt: str,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Invoke the text generation inference for the given prompt and generation parameters.

        :param prompt:
            A string representing the prompt.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :returns:
            A dictionary with the generated replies and metadata. Both are lists of length n.
            - replies: A list of strings representing the generated replies.
            - meta: A list of dictionaries with the metadata for each reply.
        """
        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # check if streaming_callback is passed
        streaming_callback = streaming_callback or self.streaming_callback

        hf_output = self._client.text_generation(
            prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
        )

        if streaming_callback is not None:
            return self._stream_and_build_response(hf_output, streaming_callback)

        # mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it
        return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))

    def _stream_and_build_response(
        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
    ):
        chunks: List[StreamingChunk] = []
        first_chunk_time = None

        for chunk in hf_output:
            token: TextGenerationStreamOutputToken = chunk.token
            if token.special:
                continue

            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            stream_chunk = StreamingChunk(token.text, chunk_metadata)
            chunks.append(stream_chunk)
            streaming_callback(stream_chunk)

        metadata = {
            "finish_reason": chunks[-1].meta.get("finish_reason", None),
            "model": self._client.model,
            "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)},
            "completion_start_time": first_chunk_time,
        }
        return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

    def _build_non_streaming_response(self, hf_output: "TextGenerationOutput"):
        meta = [
            {
                "model": self._client.model,
                "finish_reason": hf_output.details.finish_reason if hf_output.details else None,
                "usage": {"completion_tokens": len(hf_output.details.tokens) if hf_output.details else 0},
            }
        ]
        return {"replies": [hf_output.generated_text], "meta": meta}
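The docstring examples above only show non-streaming calls. Here is a minimal streaming sketch, assuming the same illustrative model as the docstring examples and a token available via the `HF_API_TOKEN` env var: the callback receives each `StreamingChunk` as it arrives, and `run` still returns the assembled reply plus metadata.

```python
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.dataclasses import StreamingChunk

# Print each token as it arrives; `chunk.content` holds the token text.
def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model, as in the docstring
    streaming_callback=print_chunk,
)

# Per-call generation_kwargs are merged over the defaults set in __init__ (the run() values win).
result = generator.run(
    prompt="What's Natural Language Processing?",
    generation_kwargs={"max_new_tokens": 64},
)
print(result["meta"][0]["finish_reason"], result["meta"][0]["usage"])
```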
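And a round-trip sketch for the `to_dict`/`from_dict` pair: `to_dict` stores the token as its `Secret` descriptor rather than the raw value, and `from_dict` restores both the secret and any named streaming callback. This sketch assumes network access, since constructing against the serverless API validates the model against the Hugging Face Hub.

```python
from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
)

data = generator.to_dict()
# The token appears as an env-var Secret descriptor, never a raw credential.
print(data["init_parameters"]["token"])

restored = HuggingFaceAPIGenerator.from_dict(data)
assert restored.api_params == generator.api_params
assert restored.generation_kwargs == generator.generation_kwargs
```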