Pull Request #8806: feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode

haystack/components/embedders/sentence_transformers_text_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Literal, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders.backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
class SentenceTransformersTextEmbedder:
    """
    Embeds strings using Sentence Transformers models.

    You can use it to embed a user query and send it to an embedding retriever.

    Usage example:
    ```python
    from haystack.components.embedders import SentenceTransformersTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = SentenceTransformersTextEmbedder()
    text_embedder.warm_up()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [-0.07804739475250244, 0.1498992145061493, ...]}
    ```
    """

    def __init__(  # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
        self,
        model: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        trust_remote_code: bool = False,
        truncate_dim: Optional[int] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        config_kwargs: Optional[Dict[str, Any]] = None,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        encode_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a SentenceTransformersTextEmbedder component.

        :param model:
            The model to use for calculating embeddings.
            Specify the path to a local model or the ID of the model on Hugging Face.
        :param device:
            Overrides the default device used to load the model.
        :param token:
            An API token to use private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each text to be embedded.
            You can use it to prepend the text with an instruction, as required by some embedding models,
            such as E5 and bge.
        :param suffix:
            A string to add at the end of each text to embed.
        :param batch_size:
            Number of texts to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar for calculating embeddings.
            If `False`, disables the progress bar.
        :param normalize_embeddings:
            If `True`, the embeddings are normalized using L2 normalization, so that the embeddings have a norm of 1.
        :param trust_remote_code:
            If `False`, permits only Hugging Face verified model architectures.
            If `True`, permits custom models and scripts.
        :param truncate_dim:
            The dimension to truncate sentence embeddings to. `None` does no truncation.
            If the model has not been trained with Matryoshka Representation Learning,
            truncation of embeddings can significantly affect performance.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param precision:
            The precision to use for the embeddings.
            All non-float32 precisions are quantized embeddings.
            Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
            They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
        :param encode_kwargs:
            Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
            avoid passing parameters that change the output type.
        """

        self.model = model
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.trust_remote_code = trust_remote_code
        self.truncate_dim = truncate_dim
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        self.encode_kwargs = encode_kwargs
        self.embedding_backend = None
        self.precision = precision
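
    # A minimal sketch of the `encode_kwargs` passthrough (hypothetical values, not
    # from this file): any extra keyword arguments are forwarded verbatim to
    # `SentenceTransformer.encode` at run time. `prompt_name` only works for models
    # that configure a prompt under that name, and arguments covered by the
    # constructor (`batch_size`, `progress_bar`, `normalize_embeddings`, `precision`)
    # should be set there instead to avoid clashes:
    #
    #   embedder = SentenceTransformersTextEmbedder(
    #       model="your-org/model-with-prompts",  # hypothetical model defining a "query" prompt
    #       encode_kwargs={"prompt_name": "query"},
    #   )
    #   embedder.warm_up()
    #   result = embedder.run("What is the capital of France?")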

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            trust_remote_code=self.trust_remote_code,
            truncate_dim=self.truncate_dim,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            precision=self.precision,
            encode_kwargs=self.encode_kwargs,
        )
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)
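
    # A minimal serialization round-trip sketch (not from this file): `to_dict`
    # produces the plain dictionary that pipeline YAML serialization relies on,
    # and `from_dict` rebuilds an equivalent, not-yet-warmed-up component.
    #
    #   embedder = SentenceTransformersTextEmbedder(normalize_embeddings=True)
    #   data = embedder.to_dict()
    #   restored = SentenceTransformersTextEmbedder.from_dict(data)
    #   assert restored.normalize_embeddings is True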

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                truncate_dim=self.truncate_dim,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
            )
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        """
        if not isinstance(text, str):
            raise TypeError(
                "SentenceTransformersTextEmbedder expects a string as input. "
                "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
            )
        if self.embedding_backend is None:
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        text_to_embed = self.prefix + text + self.suffix
        embedding = self.embedding_backend.embed(
            [text_to_embed],
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            **(self.encode_kwargs if self.encode_kwargs else {}),
        )[0]
        return {"embedding": embedding}
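

# A minimal end-to-end sketch of the "send it to an embedding retriever" flow the
# class docstring mentions, using Haystack's in-memory document store and retriever.
# Model defaults, document contents, and the query are illustrative assumptions.
#
#   from haystack import Document, Pipeline
#   from haystack.components.embedders import (
#       SentenceTransformersDocumentEmbedder,
#       SentenceTransformersTextEmbedder,
#   )
#   from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
#   from haystack.document_stores.in_memory import InMemoryDocumentStore
#
#   document_store = InMemoryDocumentStore()
#
#   # Index a few documents with the matching document embedder.
#   doc_embedder = SentenceTransformersDocumentEmbedder()
#   doc_embedder.warm_up()
#   docs = doc_embedder.run(documents=[Document(content="Pizza dough needs yeast.")])["documents"]
#   document_store.write_documents(docs)
#
#   # Embed the query with this component and retrieve by embedding similarity.
#   pipeline = Pipeline()
#   pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder())
#   pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store))
#   pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
#   print(pipeline.run({"text_embedder": {"text": "How is pizza dough made?"}}))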