• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18994388892

01 Nov 2025 08:58AM UTC coverage: 92.246% (+0.002%) from 92.244%
18994388892

Pull #10003

github

web-flow
Merge c08ee9b69 into c27c8e923
Pull Request #10003: feat: add revision parameter to Sentence Transformers embedder components

13502 of 14637 relevant lines covered (92.25%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.41
haystack/components/embedders/sentence_transformers_sparse_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from dataclasses import replace
1✔
6
from typing import Any, Literal, Optional
1✔
7

8
from haystack import Document, component, default_from_dict, default_to_dict
1✔
9
from haystack.components.embedders.backends.sentence_transformers_sparse_backend import (
1✔
10
    _SentenceTransformersSparseEmbeddingBackendFactory,
11
    _SentenceTransformersSparseEncoderEmbeddingBackend,
12
)
13
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
14
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
1✔
15

16

17
@component
class SentenceTransformersSparseDocumentEmbedder:
    """
    Calculates document sparse embeddings using sparse embedding models from Sentence Transformers.

    It stores the sparse embeddings in the `sparse_embedding` field of each document.
    You can also embed documents' metadata.
    Use this component in indexing pipelines to embed input documents
    and send them to DocumentWriter to write them into a Document Store.

    ### Usage example:

    ```python
    from haystack import Document
    from haystack.components.embedders import SentenceTransformersSparseDocumentEmbedder

    doc = Document(content="I love pizza!")
    doc_embedder = SentenceTransformersSparseDocumentEmbedder()
    doc_embedder.warm_up()

    result = doc_embedder.run([doc])
    print(result['documents'][0].sparse_embedding)

    # SparseEmbedding(indices=[999, 1045, ...], values=[0.918, 0.867, ...])
    ```
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        model: str = "prithivida/Splade_PP_en_v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[list[str]] = None,
        embedding_separator: str = "\n",
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        local_files_only: bool = False,
        model_kwargs: Optional[dict[str, Any]] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        config_kwargs: Optional[dict[str, Any]] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ):
        """
        Creates a SentenceTransformersSparseDocumentEmbedder component.

        :param model:
            The model to use for calculating sparse embeddings.
            Pass a local path or ID of the model on Hugging Face.
        :param device:
            The device to use for loading the model.
            Overrides the default device.
        :param token:
            The API token to download private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each document text.
        :param suffix:
            A string to add at the end of each document text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when embedding documents.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param trust_remote_code:
            If `False`, allows only Hugging Face verified model architectures.
            If `True`, allows custom models and scripts.
        :param revision:
            The specific model version to use. It can be a branch name, a tag name, or a commit id,
            for a stored model on Hugging Face.
        :param local_files_only:
            If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.
        """

        self.model = model
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.local_files_only = local_files_only
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        # Backend is created lazily in warm_up() so that construction stays cheap and model
        # downloads only happen when explicitly requested.
        self.embedding_backend: Optional[_SentenceTransformersSparseEncoderEmbeddingBackend] = None
        self.backend = backend

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            local_files_only=self.local_files_only,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            backend=self.backend,
        )
        # torch dtypes and similar objects inside model_kwargs are not JSON-serializable;
        # convert them to their string representation before returning.
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersSparseDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            # Inverse of serialize_hf_model_kwargs: restore objects (e.g. torch dtypes)
            # from their serialized string form.
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersSparseEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
                local_files_only=self.local_files_only,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
                backend=self.backend,
            )
            # Propagate a user-provided model_max_length to the loaded model so that
            # truncation honors the requested sequence length.
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]):
        """
        Embed a list of documents.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: Documents with sparse embeddings under the `sparse_embedding` field.

        :raises TypeError:
            If the input is not a list of Documents.
        :raises RuntimeError:
            If the embedding model has not been loaded via `warm_up()`.
        """
        # Only the first element is inspected as a cheap sanity check that the caller
        # passed Documents rather than plain strings.
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "SentenceTransformersSparseDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the SentenceTransformersSparseTextEmbedder."
            )
        if self.embedding_backend is None:
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        texts_to_embed = []
        for doc in documents:
            # Prepend selected, non-empty metadata values to the document content,
            # joined by embedding_separator, and wrap with prefix/suffix.
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
            ]
            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            data=texts_to_embed, batch_size=self.batch_size, show_progress_bar=self.progress_bar
        )

        # Return copies of the input documents with the sparse_embedding field populated,
        # leaving the originals untouched.
        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            documents_with_embeddings.append(replace(doc, sparse_embedding=emb))

        return {"documents": documents_with_embeddings}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc