• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15131674881

20 May 2025 07:35AM UTC coverage: 90.156% (-0.3%) from 90.471%
15131674881

Pull #9407

github

web-flow
Merge b382eca10 into 6ad23f822
Pull Request #9407: feat: stream `ToolResult` from run_async in Agent

10972 of 12170 relevant lines covered (90.16%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.84
haystack/components/rankers/sentence_transformers_diversity.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from enum import Enum
1✔
6
from typing import Any, Dict, List, Literal, Optional, Union
1✔
7

8
from haystack import Document, component, default_from_dict, default_to_dict
1✔
9
from haystack.lazy_imports import LazyImport
1✔
10
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
11
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
1✔
12

13
with LazyImport(message="Run 'pip install \"sentence-transformers>=3.0.0\"'") as torch_and_sentence_transformers_import:
1✔
14
    import torch
1✔
15
    from sentence_transformers import SentenceTransformer
1✔
16

17

18
class DiversityRankingStrategy(Enum):
    """
    Available strategies for ordering documents by diversity.
    """

    GREEDY_DIVERSITY_ORDER = "greedy_diversity_order"
    MAXIMUM_MARGIN_RELEVANCE = "maximum_margin_relevance"

    def __str__(self) -> str:
        """
        Return the string value of this strategy.
        """
        return self.value

    @staticmethod
    def from_str(string: str) -> "DiversityRankingStrategy":
        """
        Look up the strategy whose value equals the given string.

        :param string: The string value of the strategy.
        :returns: The matching strategy member.
        :raises ValueError: If no strategy has the given value.
        """
        members_by_value = {member.value: member for member in DiversityRankingStrategy}
        if string in members_by_value:
            return members_by_value[string]
        msg = f"Unknown strategy '{string}'. Supported strategies are: {list(members_by_value)}"
        raise ValueError(msg)
1✔
43

44

45
class DiversityRankingSimilarity(Enum):
    """
    Available similarity metrics for comparing embeddings.
    """

    DOT_PRODUCT = "dot_product"
    COSINE = "cosine"

    def __str__(self) -> str:
        """
        Return the string value of this similarity metric.
        """
        return self.value

    @staticmethod
    def from_str(string: str) -> "DiversityRankingSimilarity":
        """
        Look up the similarity metric whose value equals the given string.

        :param string: The string value of the similarity metric.
        :returns: The matching similarity member.
        :raises ValueError: If no similarity metric has the given value.
        """
        members_by_value = {member.value: member for member in DiversityRankingSimilarity}
        if string in members_by_value:
            return members_by_value[string]
        msg = f"Unknown similarity metric '{string}'. Supported metrics are: {list(members_by_value)}"
        raise ValueError(msg)
1✔
70

71

72
@component
1✔
73
class SentenceTransformersDiversityRanker:
1✔
74
    """
75
    A Diversity Ranker based on Sentence Transformers.
76

77
    Applies a document ranking algorithm based on one of the two strategies:
78

79
    1. Greedy Diversity Order:
80

81
        Implements a document ranking algorithm that orders documents in a way that maximizes the overall diversity
82
        of the documents based on their similarity to the query.
83

84
        It uses a pre-trained Sentence Transformers model to embed the query and
85
        the documents.
86

87
    2. Maximum Margin Relevance:
88

89
        Implements a document ranking algorithm that orders documents based on their Maximum Margin Relevance (MMR)
90
        scores.
91

92
        MMR scores are calculated for each document based on their relevance to the query and diversity from already
93
        selected documents. The algorithm iteratively selects documents based on their MMR scores, balancing between
94
        relevance to the query and diversity from already selected documents. The 'lambda_threshold' controls the
95
        trade-off between relevance and diversity.
96

97
    ### Usage example
98
    ```python
99
    from haystack import Document
100
    from haystack.components.rankers import SentenceTransformersDiversityRanker
101

102
    ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="greedy_diversity_order")
103
    ranker.warm_up()
104

105
    docs = [Document(content="Paris"), Document(content="Berlin")]
106
    query = "What is the capital of germany?"
107
    output = ranker.run(query=query, documents=docs)
108
    docs = output["documents"]
109
    ```
110
    """  # noqa: E501
111

112
    def __init__(  # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
        self,
        model: str = "sentence-transformers/all-MiniLM-L6-v2",
        top_k: int = 10,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        similarity: Union[str, DiversityRankingSimilarity] = "cosine",
        query_prefix: str = "",
        query_suffix: str = "",
        document_prefix: str = "",
        document_suffix: str = "",
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
        lambda_threshold: float = 0.5,
        model_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        config_kwargs: Optional[Dict[str, Any]] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ):
        """
        Initialize a SentenceTransformersDiversityRanker.

        :param model: Local path or name of the model in Hugging Face's model hub,
            such as `'sentence-transformers/all-MiniLM-L6-v2'`.
        :param top_k: The maximum number of Documents to return per query.
        :param device: The device on which the model is loaded. If `None`, the default device is automatically
            selected.
        :param token: The API token used to download private models from Hugging Face.
        :param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" or
            "cosine" (default).
        :param query_prefix: A string to add to the beginning of the query text before ranking.
            Can be used to prepend the text with an instruction, as required by some embedding models,
            such as E5 and BGE.
        :param query_suffix: A string to add to the end of the query text before ranking.
        :param document_prefix: A string to add to the beginning of each Document text before ranking.
            Can be used to prepend the text with an instruction, as required by some embedding models,
            such as E5 and BGE.
        :param document_suffix: A string to add to the end of each Document text before ranking.
        :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        :param strategy: The strategy to use for diversity ranking. Can be either "greedy_diversity_order" or
                         "maximum_margin_relevance".
        :param lambda_threshold: The trade-off parameter between relevance and diversity, between 0 and 1
                                 (both inclusive). Only used when strategy is "maximum_margin_relevance".
                                 If `None` is passed, 0.5 is used.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.

        :raises ValueError: If `top_k` is `None` or not greater than 0, if `similarity` or `strategy` is an
            unknown string, or if `lambda_threshold` is outside [0, 1] for the "maximum_margin_relevance" strategy.
        """
        torch_and_sentence_transformers_import.check()

        self.model_name_or_path = model
        if top_k is None or top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")
        self.top_k = top_k
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        # The model is loaded lazily in warm_up().
        self.model: Optional[SentenceTransformer] = None
        self.similarity = DiversityRankingSimilarity.from_str(similarity) if isinstance(similarity, str) else similarity
        self.query_prefix = query_prefix
        self.document_prefix = document_prefix
        self.query_suffix = query_suffix
        self.document_suffix = document_suffix
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
        # Fall back to the default only when no value was given. A truthiness check (`or 0.5`)
        # would also discard 0.0, which is a valid threshold.
        self.lambda_threshold = 0.5 if lambda_threshold is None else lambda_threshold
        self._check_lambda_threshold(self.lambda_threshold, self.strategy)
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        self.backend = backend
1✔
193

194
    def warm_up(self):
        """
        Load the Sentence Transformers model, unless it has already been loaded.
        """
        if self.model is not None:
            # Already warmed up; keep the existing model instance.
            return

        torch_device = self.device.to_torch_str()
        resolved_token = self.token.resolve_value() if self.token else None
        self.model = SentenceTransformer(
            model_name_or_path=self.model_name_or_path,
            device=torch_device,
            token=resolved_token,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            backend=self.backend,
        )
208

209
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        init_parameters = {
            "model": self.model_name_or_path,
            "top_k": self.top_k,
            "device": self.device.to_dict(),
            "token": self.token.to_dict() if self.token else None,
            "similarity": str(self.similarity),
            "query_prefix": self.query_prefix,
            "query_suffix": self.query_suffix,
            "document_prefix": self.document_prefix,
            "document_suffix": self.document_suffix,
            "meta_fields_to_embed": self.meta_fields_to_embed,
            "embedding_separator": self.embedding_separator,
            "strategy": str(self.strategy),
            "lambda_threshold": self.lambda_threshold,
            "model_kwargs": self.model_kwargs,
            "tokenizer_kwargs": self.tokenizer_kwargs,
            "config_kwargs": self.config_kwargs,
            "backend": self.backend,
        }
        serialized = default_to_dict(self, **init_parameters)

        # The helper's return value is not used, so it presumably converts the kwargs in place.
        hf_model_kwargs = serialized["init_parameters"].get("model_kwargs")
        if hf_model_kwargs is not None:
            serialize_hf_model_kwargs(hf_model_kwargs)
        return serialized
1✔
239

240
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
        """
        Deserializes the component from a dictionary.

        The nested `init_parameters` of `data` are updated in place before the component is built.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        params = data["init_parameters"]

        serialized_device = params.get("device")
        if serialized_device is not None:
            params["device"] = ComponentDevice.from_dict(serialized_device)

        deserialize_secrets_inplace(params, keys=["token"])

        hf_model_kwargs = params.get("model_kwargs")
        if hf_model_kwargs is not None:
            deserialize_hf_model_kwargs(hf_model_kwargs)

        return default_from_dict(cls, data)
1✔
257

258
    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Build one text per Document: prefix + selected meta values and content joined by the separator + suffix.

        Meta fields that are missing from a Document, or whose value is falsy, are skipped.
        A Document without content contributes an empty string.
        """

        def _render(doc: Document) -> str:
            meta_parts = [
                str(doc.meta[field]) for field in self.meta_fields_to_embed if field in doc.meta and doc.meta[field]
            ]
            joined = self.embedding_separator.join([*meta_parts, doc.content or ""])
            return self.document_prefix + joined + self.document_suffix

        return [_render(doc) for doc in documents]
1✔
275

276
    def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
        """
        Orders the given list of documents to maximize diversity.

        The query and all documents are embedded first. The document most similar to the query is picked as the
        starting point. Each following step picks the not-yet-selected document with the lowest summed similarity
        to everything selected so far, until every document has been placed.

        :param query: The search query.
        :param documents: The list of Document objects to be ranked.

        :return: A list of documents ordered to maximize diversity.
        """
        doc_embeddings, query_embedding = self._embed_and_normalize(
            query, self._prepare_texts_to_embed(documents)
        )
        total = len(documents)

        # Seed the ordering with the document that is most similar to the query.
        first = int(torch.argmax(query_embedding @ doc_embeddings.T).item())
        order: List[int] = [first]

        # Running sum of the selected embeddings, scaled by 1/total for numerical stability.
        # The dot product is distributive, so one product with this sum scores a candidate
        # against all selected documents at once.
        running_sum = doc_embeddings[first] / total

        # One document is already placed, so exactly total - 1 picks remain.
        for _ in range(total - 1):
            scores = running_sum @ doc_embeddings.T
            # Exclude documents that were already picked from the argmin below.
            scores[order] = torch.inf
            pick = int(torch.argmin(scores).item())
            order.append(pick)
            running_sum += doc_embeddings[pick] / total

        return [documents[position] for position in order]
1✔
321

322
    def _embed_and_normalize(self, query, texts_to_embed):
        """
        Embed the document texts and the query (wrapped in the query prefix/suffix).

        When the configured similarity is cosine, both results are scaled in place to unit L2 norm
        along the last dimension, so that dot products between them are cosine similarities.

        :param query: The search query.
        :param texts_to_embed: The prepared document texts.
        :returns: A tuple `(doc_embeddings, query_embedding)`.
        """
        assert self.model is not None  # verified in run but mypy doesn't see it

        doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True)
        full_query = self.query_prefix + query + self.query_suffix
        query_embedding = self.model.encode([full_query], convert_to_tensor=True)

        if self.similarity == DiversityRankingSimilarity.COSINE:
            for embeddings in (doc_embeddings, query_embedding):
                # In-place division; keepdim=True keeps the norm broadcastable against the embeddings.
                embeddings /= torch.norm(embeddings, p=2, dim=-1, keepdim=True)

        return doc_embeddings, query_embedding
1✔
334

335
    def _maximum_margin_relevance(
        self, query: str, documents: List[Document], lambda_threshold: float, top_k: int
    ) -> List[Document]:
        """
        Orders the given list of documents according to the Maximum Margin Relevance (MMR) scores.

        MMR scores are calculated for each document based on their relevance to the query and diversity from already
        selected documents.

        The algorithm iteratively selects documents based on their MMR scores, balancing between relevance to the query
        and diversity from already selected documents. The 'lambda_threshold' controls the trade-off between relevance
        and diversity.

        A closer value to 0 favors diversity, while a closer value to 1 favors relevance to the query.

        See : "The Use of MMR, Diversity-Based Reranking for Reordering Documents and Producing Summaries"
               https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf

        :param query: The search query.
        :param documents: The list of Document objects to be ranked. Must not be empty.
        :param lambda_threshold: The trade-off between relevance (1) and diversity (0).
        :param top_k: The number of documents to select. A falsy value, or a value larger than the number of
            documents, selects all documents.
        :returns: The selected documents in the order they were picked.
        :raises ValueError: If no candidate can be selected in an iteration (for example, when every
            MMR score is NaN).
        """

        texts_to_embed = self._prepare_texts_to_embed(documents)
        doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
        # Never try to pick more documents than exist. Without the clamp, the loop below would run
        # out of candidates and raise once every document has been selected.
        top_k = min(top_k, len(documents)) if top_k else len(documents)

        query_similarities = (query_embedding @ doc_embeddings.T).reshape(-1)

        # Start with the document that is most similar to the query.
        selected: List[int] = [int(torch.argmax(query_similarities))]
        already_selected = set(selected)  # O(1) membership checks in the inner loop

        while len(selected) < top_k:
            best_idx = None
            best_score = -float("inf")
            # The selected set is fixed within one pass, so gather its embeddings only once.
            # Assumes doc_embeddings is 2-D (one row per document), so each row is a 1-D vector.
            selected_embeddings = doc_embeddings[selected]
            for candidate in range(len(documents)):
                if candidate in already_selected:
                    continue
                relevance_score = query_similarities[candidate]
                # Redundancy is the highest similarity to any document selected so far.
                diversity_score = torch.max(selected_embeddings @ doc_embeddings[candidate])
                mmr_score = lambda_threshold * relevance_score - (1 - lambda_threshold) * diversity_score
                if mmr_score > best_score:
                    best_score = mmr_score
                    best_idx = candidate
            if best_idx is None:
                raise ValueError("No best document found, check if the documents list contains any documents.")
            selected.append(best_idx)
            already_selected.add(best_idx)

        return [documents[i] for i in selected]
1✔
383

384
    @staticmethod
    def _check_lambda_threshold(lambda_threshold: float, strategy: DiversityRankingStrategy):
        """
        Validate the relevance/diversity trade-off parameter.

        The value is only checked for the Maximum Margin Relevance strategy; for any other strategy
        this is a no-op.

        :param lambda_threshold: The value to validate.
        :param strategy: The ranking strategy the value will be used with.
        :raises ValueError: If the strategy is Maximum Margin Relevance and the value is not within [0, 1].
        """
        if strategy != DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
            return

        within_range = 0 <= lambda_threshold <= 1
        if not within_range:
            raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.")
×
388

389
    @component.output_types(documents=List[Document])
    def run(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        lambda_threshold: Optional[float] = None,
    ) -> Dict[str, List[Document]]:
        """
        Rank the documents based on their diversity.

        :param query: The search query.
        :param documents: List of Document objects to be ranked.
        :param top_k: Optional. An integer to override the top_k set during initialization.
        :param lambda_threshold: Override the trade-off parameter between relevance and diversity. Only used when
                                strategy is "maximum_margin_relevance".

        :returns: A dictionary with the following key:
            - `documents`: List of Document objects that have been selected based on the diversity ranking.

        :raises ValueError: If `top_k` is passed and is not between 1 and the number of documents, or if the
            lambda threshold in effect is outside [0, 1] for the "maximum_margin_relevance" strategy.
        :raises RuntimeError: If the component has not been warmed up.
        """
        if self.model is None:
            raise RuntimeError(
                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )

        # Nothing to rank.
        if not documents:
            return {"documents": []}

        num_documents = len(documents)
        if top_k is None:
            top_k = self.top_k
        elif not 0 < top_k <= num_documents:
            raise ValueError(f"top_k must be between 1 and {num_documents}, but got {top_k}")

        if self.strategy != DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
            ranked = self._greedy_diversity_order(query=query, documents=documents)
        else:
            # A per-call value takes precedence over the one configured at init time.
            effective_lambda = self.lambda_threshold if lambda_threshold is None else lambda_threshold
            self._check_lambda_threshold(effective_lambda, self.strategy)
            ranked = self._maximum_margin_relevance(
                query=query, documents=documents, lambda_threshold=effective_lambda, top_k=top_k
            )

        return {"documents": ranked[:top_k]}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc