deepset-ai / haystack, build 13077074655
31 Jan 2025 04:47PM UTC. Coverage: 91.363% (+0.004%) from 91.359%.
Pull Request #8794: refactor: HF API Embedders refactoring (merge 6c32f8328 into 379711f63, via GitHub / web-flow)
8875 of 9714 relevant lines covered (91.36%), 0.91 hits per line.

Source file: haystack/components/embedders/hugging_face_api_text_embedder.py (96.55% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)


@component
class HuggingFaceAPITextEmbedder:
    """
    Embeds strings using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
                                               api_params={"url": "<your-inference-endpoint-url>"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
                                               api_params={"url": "http://localhost:8080"})

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```
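
    #### Serialization and deserialization

    A minimal sketch of saving and restoring the component with `to_dict` and `from_dict`;
    the model name is the same illustrative one used above, and the token is read from the
    default `HF_API_TOKEN`/`HF_TOKEN` environment variables:

    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"})

    # to_dict() captures the init parameters; the default env-var token is stored
    # as a reference, not as a raw value.
    data = text_embedder.to_dict()
    restored_embedder = HuggingFaceAPITextEmbedder.from_dict(data)

    print(restored_embedder.run("I love pizza!"))
    ```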
    """
74

75
    def __init__(
1✔
76
        self,
77
        api_type: Union[HFEmbeddingAPIType, str],
78
        api_params: Dict[str, str],
79
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
80
        prefix: str = "",
81
        suffix: str = "",
82
        truncate: bool = True,
83
        normalize: bool = False,
84
    ):  # pylint: disable=too-many-positional-arguments
85
        """
86
        Creates a HuggingFaceAPITextEmbedder component.
87

88
        :param api_type:
89
            The type of Hugging Face API to use.
90
        :param api_params:
91
            A dictionary with the following keys:
92
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
93
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
94
            `TEXT_EMBEDDINGS_INFERENCE`.
95
        :param token: The Hugging Face token to use as HTTP bearer authorization.
96
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
97
        :param prefix:
98
            A string to add at the beginning of each text.
99
        :param suffix:
100
            A string to add at the end of each text.
101
        :param truncate:
102
            Truncates the input text to the maximum length supported by the model.
103
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
104
            if the backend uses Text Embeddings Inference.
105
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
106
        :param normalize:
107
            Normalizes the embeddings to unit length.
108
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
109
            if the backend uses Text Embeddings Inference.
110
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
111
        """
112
        huggingface_hub_import.check()
1✔
113

114
        if isinstance(api_type, str):
1✔
115
            api_type = HFEmbeddingAPIType.from_str(api_type)
1✔
116

117
        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
1✔
118
            model = api_params.get("model")
1✔
119
            if model is None:
1✔
120
                raise ValueError(
1✔
121
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
122
                )
123
            check_valid_model(model, HFModelType.EMBEDDING, token)
1✔
124
            model_or_url = model
1✔
125
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
1✔
126
            url = api_params.get("url")
1✔
127
            if url is None:
1✔
128
                msg = (
1✔
129
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
130
                    "parameter in `api_params`."
131
                )
132
                raise ValueError(msg)
1✔
133
            if not is_valid_http_url(url):
1✔
134
                raise ValueError(f"Invalid URL: {url}")
1✔
135
            model_or_url = url
1✔
136
        else:
137
            msg = f"Unknown api_type {api_type}"
×
138
            raise ValueError(msg)
×
139

140
        self.api_type = api_type
1✔
141
        self.api_params = api_params
1✔
142
        self.token = token
1✔
143
        self.prefix = prefix
1✔
144
        self.suffix = suffix
1✔
145
        self.truncate = truncate
1✔
146
        self.normalize = normalize
1✔
147
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
1✔
148

149
    def to_dict(self) -> Dict[str, Any]:
1✔
150
        """
151
        Serializes the component to a dictionary.
152

153
        :returns:
154
            Dictionary with serialized data.
155
        """
156
        return default_to_dict(
1✔
157
            self,
158
            api_type=str(self.api_type),
159
            api_params=self.api_params,
160
            prefix=self.prefix,
161
            suffix=self.suffix,
162
            token=self.token.to_dict() if self.token else None,
163
            truncate=self.truncate,
164
            normalize=self.normalize,
165
        )
166

167
    @classmethod
1✔
168
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
1✔
169
        """
170
        Deserializes the component from a dictionary.
171

172
        :param data:
173
            Dictionary to deserialize from.
174
        :returns:
175
            Deserialized component.
176
        """
177
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
1✔
178
        return default_from_dict(cls, data)
1✔
179

180
    @component.output_types(embedding=List[float])
1✔
181
    def run(self, text: str):
1✔
182
        """
183
        Embeds a single string.
184

185
        :param text:
186
            Text to embed.
187

188
        :returns:
189
            A dictionary with the following keys:
190
            - `embedding`: The embedding of the input text.
191
        """
192
        if not isinstance(text, str):
1✔
193
            raise TypeError(
1✔
194
                "HuggingFaceAPITextEmbedder expects a string as an input."
195
                "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
196
            )
197

198
        text_to_embed = self.prefix + text + self.suffix
1✔
199

200
        np_embedding = self._client.feature_extraction(
1✔
201
            text=text_to_embed,
202
            # Serverless Inference API does not support truncate and normalize, so we pass None in the request
203
            truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None,
204
            normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None,
205
        )
206

207
        error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
1✔
208
        if np_embedding.ndim > 2:
1✔
209
            raise ValueError(error_msg)
1✔
210
        if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
1✔
211
            raise ValueError(error_msg)
1✔
212

213
        embedding = np_embedding.flatten().tolist()
1✔
214

215
        return {"embedding": embedding}
1✔
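
The `prefix`, `suffix`, `truncate`, and `normalize` parameters documented in `__init__` are most useful when the component talks to a self-hosted Text Embeddings Inference server. A minimal sketch, assuming a TEI container is reachable at `http://localhost:8080` and that the served model expects an E5-style `"query: "` prefix (both the URL and the prefix are illustrative):

```python
from haystack.components.embedders import HuggingFaceAPITextEmbedder

text_embedder = HuggingFaceAPITextEmbedder(
    api_type="text_embeddings_inference",
    api_params={"url": "http://localhost:8080"},  # illustrative TEI endpoint
    prefix="query: ",  # prepended to every input before embedding (model-specific assumption)
    truncate=True,     # forwarded to the backend: cut inputs longer than the model's max length
    normalize=True,    # forwarded to the backend: return unit-length vectors
)

result = text_embedder.run("I love pizza!")
print(len(result["embedding"]))  # dimensionality of the served model's embeddings
```

Since `run` accepts only a single string, lists of `Document` objects should go through `HuggingFaceAPIDocumentEmbedder` instead, as the `TypeError` raised in `run` points out.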