• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 16840250047

08 Aug 2025 08:49PM UTC coverage: 91.946% (-0.009%) from 91.955%
16840250047

Pull #9693

github

web-flow
Merge ec42622ce into 683c935b3
Pull Request #9693: fix: prevent in-place mutation of documents after embeddings by using deepcopy

12809 of 13931 relevant lines covered (91.95%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.52
haystack/components/embedders/hugging_face_api_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from dataclasses import replace
1✔
6
from typing import Any, Optional, Union
1✔
7

8
from tqdm import tqdm
1✔
9
from tqdm.asyncio import tqdm as async_tqdm
1✔
10

11
from haystack import component, default_from_dict, default_to_dict, logging
1✔
12
from haystack.dataclasses import Document
1✔
13
from haystack.lazy_imports import LazyImport
1✔
14
from haystack.utils import Secret, deserialize_secrets_inplace
1✔
15
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
1✔
16
from haystack.utils.url_validation import is_valid_http_url
1✔
17

18
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
1✔
19
    from huggingface_hub import AsyncInferenceClient, InferenceClient
1✔
20

21
# Module-level logger; used to warn when `truncate`/`normalize` are ignored for the Serverless Inference API.
logger = logging.getLogger(__name__)
1✔
22

23

24
@component
class HuggingFaceAPIDocumentEmbedder:
    """
    Embeds documents using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)


    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
                                                  api_params={"model": "BAAI/bge-small-en-v1.5"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    #### With paid inference endpoints

    ```python
    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
                                                  api_params={"url": "<your-inference-endpoint-url>"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
                                                  api_params={"url": "http://localhost:8080"})

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        api_type: Union[HFEmbeddingAPIType, str],
        api_params: dict[str, str],
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: Optional[bool] = True,
        normalize: Optional[bool] = False,
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[list[str]] = None,
        embedding_separator: str = "\n",
    ):  # pylint: disable=too-many-positional-arguments
        """
        Creates a HuggingFaceAPIDocumentEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param batch_size:
            Number of documents to process at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.

        :raises ValueError:
            If the `model` or `url` entry required by `api_type` is missing from `api_params`,
            if the given `url` is not a valid HTTP URL, or if `api_type` is not a known API type.
        """
        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        # Tolerate a falsy `api_params` (e.g. None); the branches below report the missing key.
        api_params = api_params or {}

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        # The sync and async clients are built from the same arguments so they target the same model/endpoint.
        client_args: dict[str, Any] = {"model": model_or_url, "token": token.resolve_value() if token else None}

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self._client = InferenceClient(**client_args)
        self._async_client = AsyncInferenceClient(**client_args)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token.to_dict() if self.token else None,
            truncate=self.truncate,
            normalize=self.normalize,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # The `token` entry is deserialized in place before the component is rebuilt.
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    @staticmethod
    def _check_documents(documents: Any) -> None:
        """
        Validate that the input is a list whose first element (if any) is a Document.

        :raises TypeError:
            If `documents` is not a list, or it is non-empty and its first element is not a `Document`.
        """
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
                " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
            )

    @staticmethod
    def _check_batch_embeddings(np_embeddings: Any, expected_rows: int) -> None:
        """
        Validate that a batch result is two-dimensional with one row per input text.

        :raises ValueError:
            If `np_embeddings` is not 2-dimensional or its first dimension differs from `expected_rows`.
        """
        if np_embeddings.ndim != 2 or np_embeddings.shape[0] != expected_rows:
            # Report the size of the batch actually sent; the last batch can be smaller than `batch_size`.
            raise ValueError(f"Expected embedding shape ({expected_rows}, embedding_dim), got {np_embeddings.shape}")

    def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        Metadata values come first, in the order of `meta_fields_to_embed`; keys that are missing or whose
        value is `None` are skipped. A Document without content contributes an empty string.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]
            joined = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""])
            texts_to_embed.append(f"{self.prefix}{joined}{self.suffix}")
        return texts_to_embed

    @staticmethod
    def _adjust_api_parameters(
        truncate: Optional[bool], normalize: Optional[bool], api_type: HFEmbeddingAPIType
    ) -> tuple[Optional[bool], Optional[bool]]:
        """
        Adjust the truncate and normalize parameters based on the API type.

        For the Serverless Inference API, any non-`None` value is replaced by `None` and a warning is logged;
        for the other API types the values are returned unchanged.
        """
        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            if truncate is not None:
                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                truncate = None
            if normalize is not None:
                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                normalize = None
        return truncate, normalize

    def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
        """
        Embed a list of texts in batches.

        :raises ValueError:
            If the client returns embeddings whose shape does not match the batch that was sent.
        """
        truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)

        all_embeddings: list = []
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]

            np_embeddings = self._client.feature_extraction(
                # this method does not officially support list of strings, but works as expected
                text=batch,  # type: ignore[arg-type]
                truncate=truncate,
                normalize=normalize,
            )

            self._check_batch_embeddings(np_embeddings, expected_rows=len(batch))
            all_embeddings.extend(np_embeddings.tolist())

        return all_embeddings

    async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
        """
        Embed a list of texts in batches asynchronously.

        :raises ValueError:
            If the client returns embeddings whose shape does not match the batch that was sent.
        """
        truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)

        all_embeddings: list = []
        for i in async_tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]

            np_embeddings = await self._async_client.feature_extraction(
                # this method does not officially support list of strings, but works as expected
                text=batch,  # type: ignore[arg-type]
                truncate=truncate,
                normalize=normalize,
            )

            self._check_batch_embeddings(np_embeddings, expected_rows=len(batch))
            all_embeddings.extend(np_embeddings.tolist())

        return all_embeddings

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]):
        """
        Embeds a list of documents.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.

        :raises TypeError:
            If `documents` is not a list of `Document` objects.
        """
        self._check_documents(documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # `replace` builds new Document objects, so the caller's documents are not mutated.
        new_documents = [replace(doc, embedding=emb) for doc, emb in zip(documents, embeddings)]

        return {"documents": new_documents}

    @component.output_types(documents=list[Document])
    async def run_async(self, documents: list[Document]):
        """
        Embeds a list of documents asynchronously.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.

        :raises TypeError:
            If `documents` is not a list of `Document` objects.
        """
        self._check_documents(documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # `replace` builds new Document objects, so the caller's documents are not mutated.
        new_documents = [replace(doc, embedding=emb) for doc, emb in zip(documents, embeddings)]

        return {"documents": new_documents}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc