haystack/components/embedders/sentence_transformers_text_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Literal, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders.backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackend,
    _SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
class SentenceTransformersTextEmbedder:
    """
    Embeds strings using Sentence Transformers models.

    You can use it to embed a user query and send it to an embedding retriever.

    Usage example:
    ```python
    from haystack.components.embedders import SentenceTransformersTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = SentenceTransformersTextEmbedder()
    text_embedder.warm_up()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [-0.07804739475250244, 0.1498992145061493, ...]}
    ```
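
    To send the query embedding to a retriever, connect the embedder's `embedding` output
    to the retriever's `query_embedding` input. A minimal pipeline sketch, assuming an
    `InMemoryDocumentStore` already populated with embedded documents:
    ```python
    from haystack import Pipeline
    from haystack.components.embedders import SentenceTransformersTextEmbedder
    from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    document_store = InMemoryDocumentStore()  # assumed to hold documents with embeddings

    pipeline = Pipeline()
    pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder())
    pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store))
    pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

    result = pipeline.run({"text_embedder": {"text": "I love pizza!"}})
    print(result["retriever"]["documents"])
    ```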
    """

    def __init__(  # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
        self,
        model: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        trust_remote_code: bool = False,
        local_files_only: bool = False,
        truncate_dim: Optional[int] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        config_kwargs: Optional[dict[str, Any]] = None,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        encode_kwargs: Optional[dict[str, Any]] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
        revision: Optional[str] = None,
    ):
        """
        Create a SentenceTransformersTextEmbedder component.

        :param model:
            The model to use for calculating embeddings.
            Specify the path to a local model or the ID of the model on Hugging Face.
        :param device:
            Overrides the default device used to load the model.
        :param token:
            An API token to use private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each text to be embedded.
            You can use it to prepend the text with an instruction, as required by some embedding models,
            such as E5 and bge.
        :param suffix:
            A string to add at the end of each text to embed.
        :param batch_size:
            Number of texts to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar for calculating embeddings.
            If `False`, disables the progress bar.
        :param normalize_embeddings:
            If `True`, the embeddings are normalized using L2 normalization, so that the embeddings have a norm of 1.
        :param trust_remote_code:
            If `False`, permits only Hugging Face verified model architectures.
            If `True`, permits custom models and scripts.
        :param local_files_only:
            If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
        :param truncate_dim:
            The dimension to truncate sentence embeddings to. `None` does no truncation.
            If the model has not been trained with Matryoshka Representation Learning,
            truncating embeddings can significantly affect performance.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param precision:
            The precision to use for the embeddings.
            All non-float32 precisions are quantized embeddings.
            Quantized embeddings are smaller in size and faster to compute, but may have lower accuracy.
            They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
        :param encode_kwargs:
            Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
            avoid passing parameters that change the output type.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.
        :param revision:
            The specific model version to use. It can be a branch name, a tag name, or a commit id,
            for a stored model on Hugging Face.
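
        A minimal construction sketch; the model and the "query: " prefix are illustrative
        assumptions, following the instruction-prefix convention of E5-style models:
        ```python
        embedder = SentenceTransformersTextEmbedder(
            model="intfloat/e5-base-v2",  # assumed example model, not the component default
            prefix="query: ",  # E5-style models expect an instruction prefix on queries
            normalize_embeddings=True,  # unit-norm vectors, convenient for cosine similarity
        )
        ```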
        """

        self.model = model
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.local_files_only = local_files_only
        self.truncate_dim = truncate_dim
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        self.encode_kwargs = encode_kwargs
        self.embedding_backend: Optional[_SentenceTransformersEmbeddingBackend] = None
        self.precision = precision
        self.backend = backend

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
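
        A round-trip sketch (`embedder` is assumed to be an already constructed instance):
        ```python
        data = embedder.to_dict()
        # data["type"] names the component class; data["init_parameters"] holds the settings
        restored = SentenceTransformersTextEmbedder.from_dict(data)
        ```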
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            local_files_only=self.local_files_only,
            truncate_dim=self.truncate_dim,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            precision=self.precision,
            encode_kwargs=self.encode_kwargs,
            backend=self.backend,
        )
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersTextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the component by loading the embedding model.
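
        Call this once before `run`; repeated calls are no-ops once the backend is loaded.
        A short sketch, assuming you want to cap the sequence length via `tokenizer_kwargs`
        (the backend's `max_seq_length` is then set from `model_max_length`):
        ```python
        embedder = SentenceTransformersTextEmbedder(tokenizer_kwargs={"model_max_length": 256})
        embedder.warm_up()  # loads the Sentence Transformers model
        ```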
        """
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
                local_files_only=self.local_files_only,
                truncate_dim=self.truncate_dim,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
                backend=self.backend,
            )
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(embedding=list[float])
    def run(self, text: str):
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
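
        A short sketch of prefix handling; the "query: " prefix is an illustrative assumption:
        ```python
        embedder = SentenceTransformersTextEmbedder(prefix="query: ")
        embedder.warm_up()
        result = embedder.run(text="What is the capital of France?")
        # the backend embeds "query: What is the capital of France?"
        vector = result["embedding"]  # list[float]
        ```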
        """
        if not isinstance(text, str):
            raise TypeError(
                "SentenceTransformersTextEmbedder expects a string as input. "
                "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
            )
        if self.embedding_backend is None:
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        text_to_embed = self.prefix + text + self.suffix
        embedding = self.embedding_backend.embed(
            [text_to_embed],
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            **(self.encode_kwargs if self.encode_kwargs else {}),
        )[0]
        return {"embedding": embedding}