deepset-ai / haystack / build 18994388892

01 Nov 2025 08:58AM UTC coverage: 92.246% (+0.002%) from 92.244%

Pull Request #10003: feat: add revision parameter to Sentence Transformers embedder components
Merge c08ee9b69 into c27c8e923 (github / web-flow)

13502 of 14637 relevant lines covered (92.25%)
0.92 hits per line
Source File
97.01% of lines covered
haystack/components/embedders/sentence_transformers_document_embedder.py

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import replace
from typing import Any, Literal, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.embedders.backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackend,
    _SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
class SentenceTransformersDocumentEmbedder:
    """
    Calculates document embeddings using Sentence Transformers models.

    It stores the embeddings in the `embedding` field of each document.
    You can also embed documents' metadata.
    Use this component in indexing pipelines to embed input documents
    and send them to DocumentWriter to write them into a Document Store.

    ### Usage example:

    ```python
    from haystack import Document
    from haystack.components.embedders import SentenceTransformersDocumentEmbedder

    doc = Document(content="I love pizza!")
    doc_embedder = SentenceTransformersDocumentEmbedder()
    doc_embedder.warm_up()

    result = doc_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [-0.07804739475250244, 0.1498992145061493, ...]
    ```
    """

    def __init__(  # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
        self,
        model: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        meta_fields_to_embed: Optional[list[str]] = None,
        embedding_separator: str = "\n",
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        local_files_only: bool = False,
        truncate_dim: Optional[int] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        config_kwargs: Optional[dict[str, Any]] = None,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        encode_kwargs: Optional[dict[str, Any]] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ):
        """
        Creates a SentenceTransformersDocumentEmbedder component.

        :param model:
            The model to use for calculating embeddings.
            Pass a local path or the ID of the model on Hugging Face.
        :param device:
            The device to use for loading the model.
            Overrides the default device.
        :param token:
            The API token to download private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each document text.
            Can be used to prepend the text with an instruction, as required by some embedding models,
            such as E5 and bge.
        :param suffix:
            A string to add at the end of each document text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when embedding documents.
        :param normalize_embeddings:
            If `True`, the embeddings are normalized using L2 normalization, so that each embedding has a norm of 1.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param trust_remote_code:
            If `False`, allows only Hugging Face verified model architectures.
            If `True`, allows custom models and scripts.
        :param revision:
            The specific model version to use. It can be a branch name, a tag name, or a commit id
            of a model stored on Hugging Face.
        :param local_files_only:
            If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
        :param truncate_dim:
            The dimension to truncate sentence embeddings to. `None` does no truncation.
            If the model wasn't trained with Matryoshka Representation Learning,
            truncating embeddings can significantly affect performance.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param precision:
            The precision to use for the embeddings.
            All non-float32 precisions are quantized embeddings.
            Quantized embeddings are smaller and faster to compute, but may have lower accuracy.
            They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
        :param encode_kwargs:
            Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
            This parameter is provided for fine-grained customization. Be careful not to clash with parameters
            that are already set, and avoid passing parameters that change the output type.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.
        """

        self.model = model
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.local_files_only = local_files_only
        self.truncate_dim = truncate_dim
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        self.encode_kwargs = encode_kwargs
        self.embedding_backend: Optional[_SentenceTransformersEmbeddingBackend] = None
        self.precision = precision
        self.backend = backend

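    # A minimal sketch of pinning a model version with the `revision` parameter added in
    # PR #10003 (the values below are illustrative, not from the original module):
    #
    #     embedder = SentenceTransformersDocumentEmbedder(
    #         model="sentence-transformers/all-mpnet-base-v2",
    #         revision="main",  # a branch name, tag name, or commit id
    #     )
    #     embedder.warm_up()  # warm_up() forwards the revision to the backend factory
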
    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            local_files_only=self.local_files_only,
            truncate_dim=self.truncate_dim,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            precision=self.precision,
            encode_kwargs=self.encode_kwargs,
            backend=self.backend,
        )
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)

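    # A round-trip sketch of the serialization hooks above (illustrative only):
    #
    #     embedder = SentenceTransformersDocumentEmbedder()
    #     data = embedder.to_dict()  # device, token, and model_kwargs become JSON-safe
    #     restored = SentenceTransformersDocumentEmbedder.from_dict(data)
    #     assert restored.model == embedder.model
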
    def warm_up(self):
        """
        Initializes the component.
        """
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
                local_files_only=self.local_files_only,
                truncate_dim=self.truncate_dim,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
                backend=self.backend,
            )
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

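    # For instance, warm_up() honors a sequence-length cap passed through tokenizer_kwargs
    # (a sketch; 256 is an arbitrary illustrative value):
    #
    #     embedder = SentenceTransformersDocumentEmbedder(tokenizer_kwargs={"model_max_length": 256})
    #     embedder.warm_up()  # the backend model's max_seq_length is now 256
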
    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]):
        """
        Embed a list of documents.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: Documents with embeddings.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the SentenceTransformersTextEmbedder."
            )
        if self.embedding_backend is None:
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
            ]
            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            **(self.encode_kwargs if self.encode_kwargs else {}),
        )

        new_documents = []
        for doc, emb in zip(documents, embeddings):
            new_documents.append(replace(doc, embedding=emb))

        return {"documents": new_documents}
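
# A hedged, self-contained usage sketch (an editorial addition, not part of the original
# module): it shows how meta_fields_to_embed and embedding_separator shape the text that
# is embedded, assuming the default model can be downloaded.
if __name__ == "__main__":
    docs = [Document(content="Margherita is a classic.", meta={"topic": "pizza"})]
    embedder = SentenceTransformersDocumentEmbedder(meta_fields_to_embed=["topic"])
    embedder.warm_up()
    # For the document above, the text passed to the model is
    # "pizza\nMargherita is a classic." (meta value + separator + content).
    result = embedder.run(docs)
    print(len(result["documents"][0].embedding))  # 768 for all-mpnet-base-v2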