deepset-ai / haystack, build 18592817487

17 Oct 2025 12:33PM UTC coverage: 92.2% (+0.1%) from 92.062%
Pull Request #9859: feat: Add FallbackChatGenerator (merge f20ff2b98 into a43c47b63)
13346 of 14475 relevant lines covered (92.2%), 0.92 hits per line

Source file: haystack/components/embedders/sentence_transformers_text_embedder.py (96.43% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Literal, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders.backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackend,
    _SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
class SentenceTransformersTextEmbedder:
    """
    Embeds strings using Sentence Transformers models.

    You can use it to embed a user query and send it to an embedding retriever.

    Usage example:
    ```python
    from haystack.components.embedders import SentenceTransformersTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = SentenceTransformersTextEmbedder()
    text_embedder.warm_up()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [-0.07804739475250244, 0.1498992145061493, ...]}
    ```
    """

    def __init__(  # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
        self,
        model: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        trust_remote_code: bool = False,
        local_files_only: bool = False,
        truncate_dim: Optional[int] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        config_kwargs: Optional[dict[str, Any]] = None,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        encode_kwargs: Optional[dict[str, Any]] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ):
        """
        Create a SentenceTransformersTextEmbedder component.

        :param model:
            The model to use for calculating embeddings.
            Specify the path to a local model or the ID of the model on Hugging Face.
        :param device:
            Overrides the default device used to load the model.
        :param token:
            An API token to use private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each text to be embedded.
            You can use it to prepend the text with an instruction, as required by some embedding models,
            such as E5 and bge.
        :param suffix:
            A string to add at the end of each text to embed.
        :param batch_size:
            Number of texts to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar for calculating embeddings.
            If `False`, disables the progress bar.
        :param normalize_embeddings:
            If `True`, the embeddings are normalized using L2 normalization, so that the embeddings have a norm of 1.
        :param trust_remote_code:
            If `False`, permits only Hugging Face verified model architectures.
            If `True`, permits custom models and scripts.
        :param local_files_only:
            If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
        :param truncate_dim:
            The dimension to truncate sentence embeddings to. `None` does no truncation.
            If the model has not been trained with Matryoshka Representation Learning,
            truncating embeddings can significantly affect performance.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param precision:
            The precision to use for the embeddings.
            All non-float32 precisions are quantized embeddings.
            Quantized embeddings are smaller in size and faster to compute, but may have lower accuracy.
            They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
        :param encode_kwargs:
            Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
            avoid passing parameters that change the output type.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.
        """

        self.model = model
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.trust_remote_code = trust_remote_code
        self.local_files_only = local_files_only
        self.truncate_dim = truncate_dim
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        self.encode_kwargs = encode_kwargs
        self.embedding_backend: Optional[_SentenceTransformersEmbeddingBackend] = None
        self.precision = precision
        self.backend = backend

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}  # × uncovered in this build

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            trust_remote_code=self.trust_remote_code,
            local_files_only=self.local_files_only,
            truncate_dim=self.truncate_dim,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            precision=self.precision,
            encode_kwargs=self.encode_kwargs,
            backend=self.backend,
        )
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersTextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                local_files_only=self.local_files_only,
                truncate_dim=self.truncate_dim,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
                backend=self.backend,
            )
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(embedding=list[float])
    def run(self, text: str):
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        """
        if not isinstance(text, str):
            raise TypeError(
                "SentenceTransformersTextEmbedder expects a string as input. "
                "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
            )
        if self.embedding_backend is None:
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")  # × uncovered in this build

        text_to_embed = self.prefix + text + self.suffix
        embedding = self.embedding_backend.embed(
            [text_to_embed],
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            **(self.encode_kwargs if self.encode_kwargs else {}),
        )[0]
        return {"embedding": embedding}
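
Below the file, a minimal sketch of the usage the class docstring describes: embedding a user query and sending it to an embedding retriever. It assumes Haystack's in-memory document store and retriever and an E5 model, whose "query: "/"passage: " instruction prefixes are exactly what the `prefix` parameter is for; the model choice and component wiring are illustrative, not part of the source above.

from haystack import Document
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

# E5-style models expect instruction prefixes: "passage: " for documents,
# "query: " for queries (see the `prefix` docstring above).
doc_embedder = SentenceTransformersDocumentEmbedder(model="intfloat/e5-base-v2", prefix="passage: ")
query_embedder = SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: ")
doc_embedder.warm_up()
query_embedder.warm_up()  # run() raises a RuntimeError if warm_up() is skipped

# Index a document with its embedding so the retriever has something to match.
document_store = InMemoryDocumentStore()
docs = doc_embedder.run([Document(content="Paris is the capital of France.")])["documents"]
document_store.write_documents(docs)

# run() returns {"embedding": list[float]}; the retriever consumes that vector.
retriever = InMemoryEmbeddingRetriever(document_store=document_store)
query_embedding = query_embedder.run("What is the capital of France?")["embedding"]
print(retriever.run(query_embedding=query_embedding, top_k=1)["documents"])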
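
And a short round-trip sketch for the `to_dict`/`from_dict` pair: `default_to_dict` produces a dictionary with `type` and `init_parameters` keys, and `model_kwargs` (if set) pass through `serialize_hf_model_kwargs` so non-JSON values such as torch dtypes survive serialization. The specific kwargs below are illustrative.

import torch

from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice

embedder = SentenceTransformersTextEmbedder(
    device=ComponentDevice.from_str("cpu"),
    model_kwargs={"torch_dtype": torch.float16},  # non-JSON value, handled by serialize_hf_model_kwargs
)

data = embedder.to_dict()
# data["init_parameters"] holds the constructor arguments, with the device and
# token serialized to nested dicts and the torch dtype serialized to a string.
restored = SentenceTransformersTextEmbedder.from_dict(data)
assert isinstance(restored.model_kwargs["torch_dtype"], torch.dtype)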