deepset-ai / haystack / build 13076938430

31 Jan 2025 04:39PM UTC coverage: 91.364% (+0.005%) from 91.359%

Pull Request #8794: refactor: HF API Embedders refactoring
Merge 3de090102 into 379711f63 (github, web-flow)

8876 of 9715 relevant lines covered (91.36%), 0.91 hits per line

Source file: haystack/components/embedders/hugging_face_api_text_embedder.py (96.61% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

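# huggingface_hub is an optional dependency: import it lazily so this module can be imported
# without it; availability is verified in __init__ via huggingface_hub_import.check().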
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)


@component
class HuggingFaceAPITextEmbedder:
    """
    Embeds strings using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret
    text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
                                               api_params={"url": "<your-inference-endpoint-url>"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
                                               api_params={"url": "http://localhost:8080"})

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```
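
    #### Serialization

    The component can be serialized with `to_dict` and restored with `from_dict`. A minimal
    round-trip sketch (the model name is illustrative; authentication comes from the
    `HF_API_TOKEN`/`HF_TOKEN` environment variable, which is serialized as an environment
    variable reference, not as a value):

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"})

    data = text_embedder.to_dict()
    restored_embedder = HuggingFaceAPITextEmbedder.from_dict(data)
    ```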
    """

    def __init__(
        self,
        api_type: Union[HFEmbeddingAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool = True,
        normalize: bool = False,
    ):  # pylint: disable=too-many-positional-arguments
        """
        Creates a HuggingFaceAPITextEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        """
        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
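        # InferenceClient accepts either a model ID (Serverless Inference API) or an endpoint URL
        # (Inference Endpoints / Text Embeddings Inference), so one client covers all supported API types.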
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token.to_dict() if self.token else None,
            truncate=self.truncate,
            normalize=self.normalize,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embeds a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        """
        if not isinstance(text, str):
            raise TypeError(
                "HuggingFaceAPITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
            )

        text_to_embed = self.prefix + text + self.suffix

        np_embedding = self._client.feature_extraction(
            text=text_to_embed,
            # Serverless Inference API does not support truncate and normalize, so we pass None in the request
            truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None,
            normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None,
        )

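        # Depending on the backend, feature_extraction may return a (1, embedding_dim) array or a flat
        # (embedding_dim,) vector; anything else is unexpected and rejected below.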
        error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
        if np_embedding.ndim > 2:
            raise ValueError(error_msg)
        if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
            raise ValueError(error_msg)

        embedding = np_embedding.flatten().tolist()

        return {"embedding": embedding}