deepset-ai/haystack coverage report, 03 Mar 2025, Pull Request #8906:
refactor!: remove `dataframe` field from `Document` and `ExtractedTableAnswer`; make `pandas` optional
Overall coverage: 90.12% (9536 of 10581 relevant lines, +0.1% from 89.99%); this file: 95.95%

Source file: haystack/document_stores/in_memory/document_store.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
import math
import re
import uuid
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import expit
from haystack.utils.filters import document_matches_filter

logger = logging.getLogger(__name__)

# Document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
# True. Scaling uses the expit function (the inverse of the logit function) after applying a scaling factor
# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method).
# A larger scaling factor decreases the scaled scores. For example, an input of 10 is scaled to 0.99 with
# BM25_SCALING_FACTOR=2 but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically.
# Increase the default if most unscaled scores are larger than expected (>30) and would otherwise all be
# incorrectly mapped to scores of ~1.
BM25_SCALING_FACTOR = 8
DOT_PRODUCT_SCALING_FACTOR = 100
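
# Illustrative sketch (not part of the original module): the arithmetic behind the
# comment above, using expit(x) = 1 / (1 + exp(-x)):
#
#   from haystack.utils import expit
#   expit(10 / 2)  # ~0.99; BM25_SCALING_FACTOR=2 pushes a raw score of 10 close to 1
#   expit(10 / 8)  # ~0.78; the default factor of 8 keeps more resolution for large scores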


@dataclass
class BM25DocumentStats:
    """
    A dataclass for managing document statistics for BM25 retrieval.

    :param freq_token: A Counter of token frequencies in the document.
    :param doc_len: The number of tokens in the document.
    """

    freq_token: Dict[str, int]
    doc_len: int


# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
_STORAGES: Dict[str, Dict[str, Document]] = {}
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {}
_AVERAGE_DOC_LEN_STORAGES: Dict[str, float] = {}
_FREQ_VOCAB_FOR_IDF_STORAGES: Dict[str, Counter] = {}


class InMemoryDocumentStore:
    """
    Stores data in memory. It's ephemeral and cannot be saved to disk.
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
        bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
        bm25_parameters: Optional[Dict] = None,
        embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
        index: Optional[str] = None,
    ):
        """
        Initializes the DocumentStore.

        :param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval.
        :param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus".
        :param bm25_parameters: Parameters for the BM25 implementation in a dictionary format.
            For example: {'k1': 1.5, 'b': 0.75, 'epsilon': 0.25}
            You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25.
        :param embedding_similarity_function: The similarity function used to compare Document embeddings.
            One of "dot_product" (default) or "cosine". To choose the most appropriate function, look for
            information about your embedding model.
        :param index: A specific index to store the documents. If not specified, a random UUID is used.
            Using the same index allows you to store documents across multiple InMemoryDocumentStore instances.
        """
        self.bm25_tokenization_regex = bm25_tokenization_regex
        self.tokenizer = re.compile(bm25_tokenization_regex).findall

        if index is None:
            index = str(uuid.uuid4())

        self.index = index
        if self.index not in _STORAGES:
            _STORAGES[self.index] = {}

        self.bm25_algorithm = bm25_algorithm
        self.bm25_algorithm_inst = self._dispatch_bm25()
        self.bm25_parameters = bm25_parameters or {}
        self.embedding_similarity_function = embedding_similarity_function

        # Per-document statistics
        if self.index not in _BM25_STATS_STORAGES:
            _BM25_STATS_STORAGES[self.index] = {}

        if self.index not in _AVERAGE_DOC_LEN_STORAGES:
            _AVERAGE_DOC_LEN_STORAGES[self.index] = 0.0

        if self.index not in _FREQ_VOCAB_FOR_IDF_STORAGES:
            _FREQ_VOCAB_FOR_IDF_STORAGES[self.index] = Counter()
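
    # Illustrative usage sketch (not part of the original module): two stores created
    # with the same explicit `index` share one global storage, so writes through one
    # instance are visible through the other:
    #
    #   store_a = InMemoryDocumentStore(index="shared")
    #   store_b = InMemoryDocumentStore(index="shared")
    #   store_a.write_documents([Document(content="hello")])
    #   assert store_b.count_documents() == 1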

    @property
    def storage(self) -> Dict[str, Document]:
        """
        Utility property that returns the storage used by this instance of InMemoryDocumentStore.
        """
        return _STORAGES.get(self.index, {})

    @property
    def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
        return _BM25_STATS_STORAGES.get(self.index, {})

    @property
    def _avg_doc_len(self) -> float:
        return _AVERAGE_DOC_LEN_STORAGES.get(self.index, 0.0)

    @_avg_doc_len.setter
    def _avg_doc_len(self, value: float):
        _AVERAGE_DOC_LEN_STORAGES[self.index] = value

    @property
    def _freq_vocab_for_idf(self) -> Counter:
        return _FREQ_VOCAB_FOR_IDF_STORAGES.get(self.index, Counter())

    def _dispatch_bm25(self):
        """
        Select the correct BM25 scoring method based on the user's specification.

        :returns:
            The BM25 scoring method.
        """
        table = {"BM25Okapi": self._score_bm25okapi, "BM25L": self._score_bm25l, "BM25Plus": self._score_bm25plus}

        if self.bm25_algorithm not in table:
            raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
        return table[self.bm25_algorithm]

    def _tokenize_bm25(self, text: str) -> List[str]:
        """
        Tokenize text using the BM25 tokenization regex.

        Here we explicitly create a tokenization method to encapsulate
        all pre-processing logic used to create BM25 tokens, such as
        lowercasing. This helps track the exact tokenization process
        used for BM25 scoring at any given time.

        :param text:
            The text to tokenize.
        :returns:
            A list of tokens.
        """
        text = text.lower()
        return self.tokenizer(text)
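
    # Illustrative sketch (not part of the original module; _tokenize_bm25 is private):
    # with the default regex r"(?u)\b\w\w+\b", tokens are lowercased runs of two or
    # more word characters, so punctuation and single-character words are dropped:
    #
    #   InMemoryDocumentStore()._tokenize_bm25("Haystack: a RAG framework!")
    #   # -> ['haystack', 'rag', 'framework']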

    def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25L scores for the given query and filtered documents.

        :param query:
            The query string.
        :param documents:
            The list of documents to score; it should be produced by
            the filter_documents method and may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25L score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        delta = self.bm25_parameters.get("delta", 0.5)

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation for all tokens."""
            idf = {}
            n_corpus = len(self._bm25_attr)
            for tok in tokens:
                n = self._freq_vocab_for_idf.get(tok, 0)
                idf[tok] = math.log((n_corpus + 1.0) / (n + 0.5)) * int(n != 0)
            return idf

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
            """Per-token BM25L computation."""
            freq_term = freq.get(token, 0.0)
            ctd = freq_term / (1 - b + b * doc_len / self._avg_doc_len)
            return (1.0 + k) * (ctd + delta) / (k + ctd + delta)

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
                score += idf[tok] * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret
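
    # Formula sketch for the method above (BM25L, as described by Lv & Zhai): for a
    # query token t with document frequency n(t) in a corpus of N documents,
    #
    #   idf(t)     = ln((N + 1) / (n(t) + 0.5))                 (0 for unseen tokens)
    #   c'(t, d)   = f(t, d) / (1 - b + b * |d| / avgdl)
    #   score(q,d) = sum_t idf(t) * (1 + k1) * (c'(t,d) + delta) / (k1 + c'(t,d) + delta)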

    def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25Okapi scores for the given query and filtered documents.

        :param query:
            The query string.
        :param documents:
            The list of documents to score; it should be produced by
            the filter_documents method and may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25Okapi score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        epsilon = self.bm25_parameters.get("epsilon", 0.25)

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation for all tokens."""
            sum_idf = 0.0
            neg_idf_tokens = []

            # Although this is a global statistic, we compute it here
            # to make the computation more self-contained. And the
            # complexity is O(vocab_size), which is acceptable.
            idf = {}
            for tok, n in self._freq_vocab_for_idf.items():
                idf[tok] = math.log((len(self._bm25_attr) - n + 0.5) / (n + 0.5))
                sum_idf += idf[tok]
                if idf[tok] < 0:
                    neg_idf_tokens.append(tok)

            eps = epsilon * sum_idf / len(self._freq_vocab_for_idf)
            for tok in neg_idf_tokens:
                idf[tok] = eps
            return {tok: idf.get(tok, 0.0) for tok in tokens}

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
            """Per-token BM25Okapi computation."""
            freq_term = freq.get(token, 0.0)
            freq_norm = freq_term + k * (1 - b + b * doc_len / self._avg_doc_len)
            return freq_term * (1.0 + k) / freq_norm

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok in idf.keys():
                score += idf[tok] * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret
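
    # Formula sketch for the method above (classic Okapi BM25): with N documents and
    # document frequency n(t),
    #
    #   idf(t)     = ln((N - n(t) + 0.5) / (n(t) + 0.5))
    #   score(q,d) = sum_t idf(t) * f(t,d) * (1 + k1) / (f(t,d) + k1 * (1 - b + b * |d| / avgdl))
    #
    # Negative idf values (very common tokens) are floored at epsilon times the average
    # idf, mirroring the rank_bm25 reference implementation.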

    def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25+ scores for the given query and filtered documents.

        This implementation follows the description on the BM25 Wikipedia page,
        which adds 1 (a smoothing factor) to the document frequency when computing the IDF.

        :param query:
            The query string.
        :param documents:
            The list of documents to score; it should be produced by
            the filter_documents method and may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25+ score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        delta = self.bm25_parameters.get("delta", 1.0)

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation."""
            idf = {}
            n_corpus = len(self._bm25_attr)
            for tok in tokens:
                n = self._freq_vocab_for_idf.get(tok, 0)
                idf[tok] = math.log(1 + (n_corpus - n + 0.5) / (n + 0.5)) * int(n != 0)
            return idf

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: float) -> float:
            """Per-token normalized term frequency."""
            freq_term = freq.get(token, 0.0)
            freq_damp = k * (1 - b + b * doc_len / self._avg_doc_len)
            return freq_term * (1.0 + k) / (freq_term + freq_damp) + delta

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
                score += idf[tok] * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret
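
    # Formula sketch for the method above (BM25+, as described by Lv & Zhai):
    #
    #   idf(t)     = ln(1 + (N - n(t) + 0.5) / (n(t) + 0.5))    (0 for unseen tokens)
    #   score(q,d) = sum_t idf(t) * (f(t,d) * (1 + k1) / (f(t,d) + k1 * (1 - b + b * |d| / avgdl)) + delta)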

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            bm25_tokenization_regex=self.bm25_tokenization_regex,
            bm25_algorithm=self.bm25_algorithm,
            bm25_parameters=self.bm25_parameters,
            embedding_similarity_function=self.embedding_similarity_function,
            index=self.index,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)
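
    # Illustrative sketch (not part of the original module): to_dict/from_dict round-trip
    # the configuration, not the stored documents. Because `index` is serialized too,
    # the restored store points at the same global storage:
    #
    #   restored = InMemoryDocumentStore.from_dict(store.to_dict())
    #   assert restored.index == store.index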

    def save_to_disk(self, path: str) -> None:
        """
        Write the database and its data to disk as a JSON file.

        :param path: The path to the JSON file.
        """
        data: Dict[str, Any] = self.to_dict()
        data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
        with open(path, "w") as f:
            json.dump(data, f)

    @classmethod
    def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
        """
        Load the database and its data from disk as a JSON file.

        :param path: The path to the JSON file.
        :returns: The loaded InMemoryDocumentStore.
        """
        if Path(path).exists():
            try:
                with open(path, "r") as f:
                    data = json.load(f)
            except Exception as e:
                raise DocumentStoreError(f"Error loading InMemoryDocumentStore from disk: {e}") from e

            documents = data.pop("documents")
            cls_object = default_from_dict(cls, data)
            cls_object.write_documents(
                documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
            )
            return cls_object
        else:
            raise FileNotFoundError(f"File {path} not found.")
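
    # Illustrative sketch (not part of the original module): unlike to_dict, the disk
    # round trip also persists the stored documents:
    #
    #   store.save_to_disk("/tmp/store.json")
    #   restored = InMemoryDocumentStore.load_from_disk("/tmp/store.json")
    #   assert restored.count_documents() == store.count_documents()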

    def count_documents(self) -> int:
        """
        Returns how many documents are present in the DocumentStore.
        """
        return len(self.storage.keys())

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol
        documentation.

        :param filters: The filters to apply to the document list.
        :returns: A list of Documents that match the given filters.
        """
        if filters:
            if "operator" not in filters and "conditions" not in filters:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
            return [doc for doc in self.storage.values() if document_matches_filter(filters=filters, document=doc)]
        return list(self.storage.values())
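
    # Illustrative sketch (not part of the original module): filters follow the Haystack
    # metadata-filtering syntax, either a single comparison or a logical dict with
    # "operator" and "conditions":
    #
    #   store.filter_documents({"field": "meta.category", "operator": "==", "value": "news"})
    #   store.filter_documents({
    #       "operator": "AND",
    #       "conditions": [
    #           {"field": "meta.category", "operator": "==", "value": "news"},
    #           {"field": "meta.year", "operator": ">=", "value": 2023},
    #       ],
    #   })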

    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
        """
        Refer to the DocumentStore.write_documents() protocol documentation.

        If `policy` is set to `DuplicatePolicy.NONE`, it defaults to `DuplicatePolicy.FAIL`.
        """
        if (
            not isinstance(documents, Iterable)
            or isinstance(documents, str)
            or any(not isinstance(doc, Document) for doc in documents)
        ):
            raise ValueError("Please provide a list of Documents.")

        if policy == DuplicatePolicy.NONE:
            policy = DuplicatePolicy.FAIL

        written_documents = len(documents)
        for document in documents:
            if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys():
                if policy == DuplicatePolicy.FAIL:
                    raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
                if policy == DuplicatePolicy.SKIP:
                    logger.warning("ID '{document_id}' already exists", document_id=document.id)
                    written_documents -= 1
                    continue

            # Since the statistics are updated in an incremental manner,
            # we need to explicitly remove the existing document to revert
            # the statistics before updating them with the new document.
            if document.id in self.storage.keys():
                self.delete_documents([document.id])

            tokens = []
            if document.content is not None:
                tokens = self._tokenize_bm25(document.content)

            self.storage[document.id] = document

            # Update the running average before registering the new document's
            # statistics, so that len(self._bm25_attr) still counts only the
            # previously stored documents.
            self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._bm25_attr)) / (len(self._bm25_attr) + 1)
            self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
            self._freq_vocab_for_idf.update(set(tokens))
        return written_documents
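
    # Illustrative sketch (not part of the original module): the effect of each policy
    # when a document with the same id is already stored:
    #
    #   store.write_documents([doc], policy=DuplicatePolicy.FAIL)       # raises DuplicateDocumentError
    #   store.write_documents([doc], policy=DuplicatePolicy.SKIP)       # returns 0 and logs a warning
    #   store.write_documents([doc], policy=DuplicatePolicy.OVERWRITE)  # returns 1, replacing the document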

    def delete_documents(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with matching document_ids from the DocumentStore.

        :param document_ids: The IDs of the documents to delete.
        """
        for doc_id in document_ids:
            if doc_id not in self.storage.keys():
                continue
            del self.storage[doc_id]

            # Update statistics accordingly
            doc_stats = self._bm25_attr.pop(doc_id)
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
            try:
                self._avg_doc_len = (self._avg_doc_len * (len(self._bm25_attr) + 1) - doc_len) / len(self._bm25_attr)
            except ZeroDivisionError:
                self._avg_doc_len = 0
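
    # Worked example for the running-average update above (an illustration, not part of
    # the original module): with three documents of lengths 2, 4, and 6, avgdl = 4.
    # Deleting the length-6 document leaves two stats entries, so the new average is
    # (4 * (2 + 1) - 6) / 2 = 3, the mean of the remaining lengths 2 and 4.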

    def bm25_retrieval(
        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
    ) -> List[Document]:
        """
        Retrieves documents that are most relevant to the query using the BM25 algorithm.

        :param query: The query string.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        if not query:
            raise ValueError("Query should be a non-empty string")

        content_type_filter = {"field": "content", "operator": "!=", "value": None}
        if filters:
            if "operator" not in filters:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
            filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
        else:
            filters = content_type_filter

        all_documents = self.filter_documents(filters=filters)
        if len(all_documents) == 0:
            logger.info("No documents found for BM25 retrieval. Returning empty list.")
            return []

        results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]

        # BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score
        # is False. It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return
        # negative scores. See https://github.com/deepset-ai/haystack/pull/6889 for more context.
        negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score

        # Create documents with the BM25 score to return them
        return_documents = []
        for doc, score in results:
            if scale_score:
                score = expit(score / BM25_SCALING_FACTOR)

            if not negatives_are_valid and score <= 0.0:
                continue

            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            return_document = Document.from_dict(doc_fields)
            return_documents.append(return_document)

        return return_documents
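
    # Illustrative sketch (not part of the original module): end-to-end BM25 retrieval.
    #
    #   store = InMemoryDocumentStore(bm25_algorithm="BM25L")
    #   store.write_documents([
    #       Document(content="BM25 is a ranking function"),
    #       Document(content="embeddings capture semantics"),
    #   ])
    #   docs = store.bm25_retrieval(query="ranking function", top_k=1)
    #   docs[0].score  # raw BM25L score; pass scale_score=True to map it into (0, 1)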

    def embedding_retrieval(  # pylint: disable=too-many-positional-arguments
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = False,
        return_embedding: bool = False,
    ) -> List[Document]:
        """
        Retrieves documents that are most similar to the query embedding using a vector similarity metric.

        :param query_embedding: Embedding of the query.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        if len(query_embedding) == 0 or not isinstance(query_embedding[0], float):
            raise ValueError("query_embedding should be a non-empty list of floats.")

        filters = filters or {}
        all_documents = self.filter_documents(filters=filters)

        documents_with_embeddings = [doc for doc in all_documents if doc.embedding is not None]
        if len(documents_with_embeddings) == 0:
            logger.warning(
                "No Documents found with embeddings. Returning empty list. "
                "To generate embeddings, use a DocumentEmbedder."
            )
            return []
        elif len(documents_with_embeddings) < len(all_documents):
            logger.info(
                "Skipping some Documents that don't have an embedding. To generate embeddings, use a DocumentEmbedder."
            )

        scores = self._compute_query_embedding_similarity_scores(
            embedding=query_embedding, documents=documents_with_embeddings, scale_score=scale_score
        )

        # Create Documents with the similarity score for the top k results
        top_documents = []
        for doc, score in sorted(zip(documents_with_embeddings, scores), key=lambda x: x[1], reverse=True)[:top_k]:
            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            if return_embedding is False:
                doc_fields["embedding"] = None
            top_documents.append(Document.from_dict(doc_fields))

        return top_documents
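
    # Illustrative sketch (not part of the original module): embedding retrieval with toy
    # 2-dimensional vectors; real embeddings come from a DocumentEmbedder:
    #
    #   store = InMemoryDocumentStore(embedding_similarity_function="cosine")
    #   store.write_documents([Document(content="hello", embedding=[1.0, 0.0])])
    #   docs = store.embedding_retrieval(query_embedding=[0.9, 0.1], top_k=1)
    #   docs[0].score  # cosine similarity in [-1, 1]; scale_score=True maps it to [0, 1]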

    def _compute_query_embedding_similarity_scores(
        self, embedding: List[float], documents: List[Document], scale_score: bool = False
    ) -> List[float]:
        """
        Computes the similarity scores between the query embedding and the embeddings of the documents.

        :param embedding: Embedding of the query.
        :param documents: A list of Documents.
        :param scale_score: Whether to scale the scores of the Documents. Default is False.
        :returns: A list of scores.
        """
        query_embedding = np.array(embedding)
        if query_embedding.ndim == 1:
            query_embedding = np.expand_dims(a=query_embedding, axis=0)

        try:
            document_embeddings = np.array([doc.embedding for doc in documents])
        except ValueError as e:
            if "inhomogeneous shape" in str(e):
                raise DocumentStoreError(
                    "The embedding size of all Documents should be the same. "
                    "Please make sure that the Documents have been embedded with the same model."
                ) from e
            raise e
        if document_embeddings.ndim == 1:
            document_embeddings = np.expand_dims(a=document_embeddings, axis=0)

        if self.embedding_similarity_function == "cosine":
            # cosine similarity is a normed dot product
            query_embedding /= np.linalg.norm(x=query_embedding, axis=1, keepdims=True)
            document_embeddings /= np.linalg.norm(x=document_embeddings, axis=1, keepdims=True)

        try:
            scores = np.dot(a=query_embedding, b=document_embeddings.T)[0].tolist()
        except ValueError as e:
            if "shapes" in str(e) and "not aligned" in str(e):
                raise DocumentStoreError(
                    "The embedding size of the query should be the same as the embedding size of the Documents. "
                    "Please make sure that the query has been embedded with the same model as the Documents."
                ) from e
            raise e

        if scale_score:
            if self.embedding_similarity_function == "dot_product":
                scores = [expit(float(score / DOT_PRODUCT_SCALING_FACTOR)) for score in scores]
            elif self.embedding_similarity_function == "cosine":
                scores = [(score + 1) / 2 for score in scores]

        return scores
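
    # Illustrative sketch (not part of the original module): how scale_score maps raw
    # similarities into comparable ranges:
    #
    #   expit(100 / DOT_PRODUCT_SCALING_FACTOR)  # dot product 100 -> ~0.73
    #   (0.5 + 1) / 2                            # cosine 0.5 -> 0.75, mapping [-1, 1] to [0, 1]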