• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18994388892

01 Nov 2025 08:58AM UTC coverage: 92.246% (+0.002%) from 92.244%
18994388892

Pull #10003

github

web-flow
Merge c08ee9b69 into c27c8e923
Pull Request #10003: feat: add revision parameter to Sentence Transformers embedder components

13502 of 14637 relevant lines covered (92.25%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.08
haystack/components/embedders/sentence_transformers_sparse_text_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Literal, Optional
1✔
6

7
from haystack import component, default_from_dict, default_to_dict
1✔
8
from haystack.components.embedders.backends.sentence_transformers_sparse_backend import (
1✔
9
    _SentenceTransformersSparseEmbeddingBackendFactory,
10
    _SentenceTransformersSparseEncoderEmbeddingBackend,
11
)
12
from haystack.dataclasses.sparse_embedding import SparseEmbedding
1✔
13
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
14
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
1✔
15

16

17
@component
class SentenceTransformersSparseTextEmbedder:
    """
    Embeds strings using sparse embedding models from Sentence Transformers.

    You can use it to embed user query and send it to a sparse embedding retriever.

    Usage example:
    ```python
    from haystack.components.embedders import SentenceTransformersSparseTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = SentenceTransformersSparseTextEmbedder()
    text_embedder.warm_up()

    print(text_embedder.run(text_to_embed))

    # {'sparse_embedding': SparseEmbedding(indices=[999, 1045, ...], values=[0.918, 0.867, ...])}
    ```
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        model: str = "prithivida/Splade_PP_en_v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        local_files_only: bool = False,
        model_kwargs: Optional[dict[str, Any]] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        config_kwargs: Optional[dict[str, Any]] = None,
        encode_kwargs: Optional[dict[str, Any]] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ):
        """
        Create a SentenceTransformersSparseTextEmbedder component.

        :param model:
            The model to use for calculating sparse embeddings.
            Specify the path to a local model or the ID of the model on Hugging Face.
        :param device:
            Overrides the default device used to load the model.
        :param token:
            An API token to use private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each text to be embedded.
        :param suffix:
            A string to add at the end of each text to embed.
        :param trust_remote_code:
            If `False`, permits only Hugging Face verified model architectures.
            If `True`, permits custom models and scripts.
        :param revision:
            The specific model version to use. It can be a branch name, a tag name, or a commit id,
            for a stored model on Hugging Face.
        :param local_files_only:
            If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param encode_kwargs:
            Additional keyword arguments for the underlying model's `encode` method when embedding texts.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.
        """

        self.model = model
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.local_files_only = local_files_only
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        # FIX: encode_kwargs was accepted by the signature but silently dropped:
        # never stored, never serialized, never used. Store it so run() can
        # forward it and to_dict()/from_dict() round-trip it.
        self.encode_kwargs = encode_kwargs
        # Lazily initialized by warm_up(); run() raises until then.
        self.embedding_backend: Optional[_SentenceTransformersSparseEncoderEmbeddingBackend] = None
        self.backend = backend

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            local_files_only=self.local_files_only,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            # FIX: include encode_kwargs so serialization round-trips it.
            encode_kwargs=self.encode_kwargs,
            backend=self.backend,
        )
        # torch dtypes and similar objects in model_kwargs are not JSON-serializable;
        # convert them to strings before returning.
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersSparseTextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the component.
        """
        # Idempotent: the backend factory is only invoked once.
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersSparseEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
                local_files_only=self.local_files_only,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
                backend=self.backend,
            )
            # model_max_length from tokenizer_kwargs must also cap the
            # SentenceTransformers-level max_seq_length to take effect.
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(sparse_embedding=SparseEmbedding)
    def run(self, text: str):
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `sparse_embedding`: The sparse embedding of the input text.
        :raises TypeError:
            If the input is not a string.
        :raises RuntimeError:
            If the component was not warmed up via `warm_up()`.
        """
        if not isinstance(text, str):
            # FIX: the concatenated message parts were missing separating
            # spaces ("input.In case ... theSentenceTransformers...").
            raise TypeError(
                "SentenceTransformersSparseTextEmbedder expects a string as input. "
                "In case you want to embed a list of Documents, please use the "
                "SentenceTransformersSparseDocumentEmbedder."
            )
        if self.embedding_backend is None:
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        text_to_embed = self.prefix + text + self.suffix

        # FIX: forward encode_kwargs to the backend (no-op when None).
        # NOTE(review): assumes the sparse backend's embed() forwards extra
        # kwargs to the model's encode method, as the dense backend does —
        # confirm against _SentenceTransformersSparseEncoderEmbeddingBackend.
        sparse_embedding = self.embedding_backend.embed(data=[text_to_embed], **(self.encode_kwargs or {}))[0]

        return {"sparse_embedding": sparse_embedding}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc