• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15131674881

20 May 2025 07:35AM UTC coverage: 90.156% (-0.3%) from 90.471%
15131674881

Pull #9407

github

web-flow
Merge b382eca10 into 6ad23f822
Pull Request #9407: feat: stream `ToolResult` from run_async in Agent

10972 of 12170 relevant lines covered (90.16%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.43
haystack/components/rankers/transformers_similarity.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from pathlib import Path
1✔
6
from typing import Any, Dict, List, Optional, Union
1✔
7

8
from haystack import Document, component, default_from_dict, default_to_dict
1✔
9
from haystack.lazy_imports import LazyImport
1✔
10
from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
1✔
11
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
1✔
12

13
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
1✔
14
    import accelerate  # pylint: disable=unused-import # the library is used but not directly referenced
1✔
15
    import torch
1✔
16
    from torch.utils.data import DataLoader, Dataset
1✔
17
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
1✔
18

19

20
@component
class TransformersSimilarityRanker:
    """
    Ranks documents based on their semantic similarity to the query.

    It uses a pre-trained cross-encoder model from Hugging Face to score each (query, document) pair.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.rankers import TransformersSimilarityRanker

    ranker = TransformersSimilarityRanker()
    docs = [Document(content="Paris"), Document(content="Berlin")]
    query = "City in Germany"
    ranker.warm_up()
    result = ranker.run(query=query, documents=docs)
    docs = result["documents"]
    print(docs[0].content)
    ```
    """

    def __init__(  # noqa: PLR0913, pylint: disable=too-many-positional-arguments
        self,
        model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        top_k: int = 10,
        query_prefix: str = "",
        document_prefix: str = "",
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        scale_score: bool = True,
        calibration_factor: Optional[float] = 1.0,
        score_threshold: Optional[float] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        batch_size: int = 16,
    ):
        """
        Creates an instance of TransformersSimilarityRanker.

        :param model:
            The ranking model. Pass a local path or the Hugging Face model name of a cross-encoder model.
        :param device:
            The device on which the model is loaded. If `None`, the device is resolved automatically
            (it is folded into `model_kwargs` as a device map by `resolve_hf_device_map`).
        :param token:
            The API token to download private models from Hugging Face.
        :param top_k:
            The maximum number of documents to return per query.
        :param query_prefix:
            A string to add at the beginning of the query text before ranking.
            Use it to prepend the text with an instruction, as required by reranking models like `bge`.
        :param document_prefix:
            A string to add at the beginning of each document before ranking. You can use it to prepend the document
            with an instruction, as required by embedding models like `bge`.
        :param meta_fields_to_embed:
            List of metadata fields to embed with the document.
        :param embedding_separator:
            Separator to concatenate metadata fields to the document.
        :param scale_score:
            If `True`, scales the raw logit predictions using a Sigmoid activation function.
            If `False`, disables scaling of the raw logit predictions.
        :param calibration_factor:
            Use this factor to calibrate probabilities with `sigmoid(logits * calibration_factor)`.
            Used only if `scale_score` is `True`.
        :param score_threshold:
            Use it to return documents with a score greater than or equal to this threshold only.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param batch_size:
            The batch size to use for inference. The higher the batch size, the more memory is required.
            If you run into memory issues, reduce the batch size.

        :raises ValueError:
            If `top_k` is not > 0.
            If `scale_score` is True and `calibration_factor` is not provided.
        """
        torch_and_transformers_import.check()

        self.model_name_or_path = str(model)
        self.model = None
        self.query_prefix = query_prefix
        self.document_prefix = document_prefix
        self.tokenizer = None
        # The concrete device is only known after warm_up(), once the model reports its device map.
        self.device = None
        self.top_k = top_k
        self.token = token
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.scale_score = scale_score
        self.calibration_factor = calibration_factor
        self.score_threshold = score_threshold

        # The requested device is merged into model_kwargs (as a device map) rather than stored separately.
        model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs or {}
        self.batch_size = batch_size

        # Parameter validation
        if self.scale_score and self.calibration_factor is None:
            raise ValueError(
                f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
            )

        if self.top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model_name_or_path}

    def warm_up(self):
        """
        Initializes the component.

        Loads the model and tokenizer if the model is not loaded yet, then records the device(s)
        the model was placed on. Calling it again after a successful load is a no-op.
        """
        if self.model is None:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name_or_path,
                token=self.token.resolve_value() if self.token else None,
                **self.tokenizer_kwargs,
            )
            assert self.model is not None
            self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            # The device is already captured inside model_kwargs (as a device map), so it is not serialized twice.
            device=None,
            model=self.model_name_or_path,
            token=self.token.to_dict() if self.token else None,
            top_k=self.top_k,
            query_prefix=self.query_prefix,
            document_prefix=self.document_prefix,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            scale_score=self.scale_score,
            calibration_factor=self.calibration_factor,
            score_threshold=self.score_threshold,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            batch_size=self.batch_size,
        )

        serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TransformersSimilarityRanker":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        deserialize_secrets_inplace(init_params, keys=["token"])

        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(  # pylint: disable=too-many-positional-arguments
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        scale_score: Optional[bool] = None,
        calibration_factor: Optional[float] = None,
        score_threshold: Optional[float] = None,
    ):
        """
        Returns a list of documents ranked by their similarity to the given query.

        Each runtime parameter left as `None` falls back to the value given at initialization.
        Explicit falsy values (for example `scale_score=False` or `score_threshold=0.0`) are honored.

        :param query:
            The input query to compare the documents to.
        :param documents:
            A list of documents to be ranked. Their `score` attribute is overwritten in place.
        :param top_k:
            The maximum number of documents to return.
        :param scale_score:
            If `True`, scales the raw logit predictions using a Sigmoid activation function.
            If `False`, disables scaling of the raw logit predictions.
        :param calibration_factor:
            Use this factor to calibrate probabilities with `sigmoid(logits * calibration_factor)`.
            Used only if `scale_score` is `True`.
        :param score_threshold:
            Use it to return documents only with a score greater than or equal to this threshold.
        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents closest to the query, sorted from most similar to least similar.

        :raises ValueError:
            If `top_k` is not > 0.
            If `scale_score` is True and `calibration_factor` is not provided.
        :raises RuntimeError:
            If the model is not loaded because `warm_up()` was not called before.
        """
        # If a model path is provided but the model isn't loaded
        if self.model is None:
            raise RuntimeError(
                "The component TransformersSimilarityRanker wasn't warmed up. Run 'warm_up()' before calling 'run()'."
            )

        if not documents:
            return {"documents": []}

        # Fall back to init values only when the caller passed nothing. Using `or` here would silently discard
        # falsy overrides such as scale_score=False, score_threshold=0.0 or top_k=0.
        top_k = self.top_k if top_k is None else top_k
        scale_score = self.scale_score if scale_score is None else scale_score
        calibration_factor = self.calibration_factor if calibration_factor is None else calibration_factor
        score_threshold = self.score_threshold if score_threshold is None else score_threshold

        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")

        if scale_score and calibration_factor is None:
            raise ValueError(
                f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
            )

        # Build one [query, document text] pair per document; non-empty selected meta values are
        # prepended to the content, joined by the configured separator.
        query_doc_pairs = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
            ]
            text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
            query_doc_pairs.append([self.query_prefix + query, self.document_prefix + text_to_embed])

        class _Dataset(Dataset):
            """Exposes the rows of a tokenizer batch encoding as per-item dicts so DataLoader can re-batch them."""

            def __init__(self, batch_encoding):
                self.batch_encoding = batch_encoding

            def __len__(self):
                return len(self.batch_encoding["input_ids"])

            def __getitem__(self, item):
                return {key: self.batch_encoding.data[key][item] for key in self.batch_encoding.data.keys()}

        batch_enc = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to(  # type: ignore
            self.device.first_device.to_torch()
        )
        dataset = _Dataset(batch_enc)
        # shuffle=False keeps score order aligned with the order of `documents`.
        inp_dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        batch_scores = []
        with torch.inference_mode():
            for features in inp_dataloader:
                # squeeze(dim=1) drops the single-logit column, assuming a one-output cross-encoder head.
                batch_scores.append(self.model(**features).logits.squeeze(dim=1))  # type: ignore
        similarity_scores = torch.cat(batch_scores)

        if scale_score:
            similarity_scores = torch.sigmoid(similarity_scores * calibration_factor)

        _, sorted_indices = torch.sort(similarity_scores, descending=True)

        sorted_indices = sorted_indices.cpu().tolist()  # type: ignore
        similarity_scores = similarity_scores.cpu().tolist()
        ranked_docs = []
        for i in sorted_indices:
            documents[i].score = similarity_scores[i]
            ranked_docs.append(documents[i])

        if score_threshold is not None:
            ranked_docs = [doc for doc in ranked_docs if doc.score >= score_threshold]

        return {"documents": ranked_docs[:top_k]}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc