• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 13718273459

07 Mar 2025 10:01AM UTC coverage: 89.949% (-0.01%) from 89.961%
13718273459

push

github

web-flow
fix: cleaning up `InMemoryDocumentStore` executor when created inside the class (#8994)

* cleaning up executor when created inside the class

* adding missed tests

9683 of 10765 relevant lines covered (89.95%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.6
haystack/document_stores/in_memory/document_store.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import asyncio
1✔
6
import json
1✔
7
import math
1✔
8
import re
1✔
9
import uuid
1✔
10
from collections import Counter
1✔
11
from concurrent.futures import ThreadPoolExecutor
1✔
12
from dataclasses import dataclass
1✔
13
from pathlib import Path
1✔
14
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
1✔
15

16
import numpy as np
1✔
17

18
from haystack import default_from_dict, default_to_dict, logging
1✔
19
from haystack.dataclasses import Document
1✔
20
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
1✔
21
from haystack.document_stores.types import DuplicatePolicy
1✔
22
from haystack.utils import expit
1✔
23
from haystack.utils.filters import document_matches_filter
1✔
24

25
logger = logging.getLogger(__name__)

# Raw document scores are essentially unbounded. When a retrieval method is called with `scale_score=True`
# (the default is False), scores are mapped into the range (0, 1) with the expit function (the inverse of the logit
# function) after being divided by a scaling factor: BM25_SCALING_FACTOR for bm25_retrieval and
# DOT_PRODUCT_SCALING_FACTOR for embedding retrieval with the "dot_product" similarity function.
# A larger scaling factor yields smaller scaled scores. For example, a raw score of 10 is scaled to 0.99 with
# BM25_SCALING_FACTOR=2 but to 0.78 with BM25_SCALING_FACTOR=8 (the default). The defaults were chosen empirically.
# Increase a factor if most unscaled scores are larger than expected (>30); otherwise they would all be mapped to
# scores of roughly 1.
BM25_SCALING_FACTOR: int = 8
DOT_PRODUCT_SCALING_FACTOR: int = 100
36

37

38
@dataclass
class BM25DocumentStats:
    """
    Per-document statistics required to compute BM25 scores.

    :param freq_token: Mapping from every token of the document to the number of times it occurs
        (the store fills it with a `collections.Counter`).
    :param doc_len: Total number of tokens the document was split into.
    """

    freq_token: Dict[str, int]
    doc_len: int
49

50

51
# Module-level registries shared by every InMemoryDocumentStore instance. Each one is keyed by the index name,
# so all stores created with the same `index` read and write the same data.
# index -> {document id -> Document}
_STORAGES: Dict[str, Dict[str, Document]] = dict()
# index -> {document id -> per-document BM25 statistics}
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = dict()
# index -> average number of BM25 tokens per document
_AVERAGE_DOC_LEN_STORAGES: Dict[str, float] = dict()
# index -> Counter mapping each token to the number of documents that contain it
_FREQ_VOCAB_FOR_IDF_STORAGES: Dict[str, Counter] = dict()
56

57

58
class InMemoryDocumentStore:
    """
    Stores Documents in memory.

    Documents and BM25 statistics live in module-level dictionaries keyed by `index`, so every instance created with
    the same index shares the same data. Nothing is persisted automatically; use `save_to_disk` to write the
    documents to a JSON file and `load_from_disk` to restore them.
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
        bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
        bm25_parameters: Optional[Dict] = None,
        embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
        index: Optional[str] = None,
        async_executor: Optional[ThreadPoolExecutor] = None,
    ):
        """
        Initializes the DocumentStore.

        :param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval.
        :param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus".
        :param bm25_parameters: Parameters for BM25 implementation in a dictionary format.
            For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25}
            You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25.
        :param embedding_similarity_function: The similarity function used to compare Documents embeddings.
            One of "dot_product" (default) or "cosine". To choose the most appropriate function, look for information
            about your embedding model.
        :param index: A specific index to store the documents. If not specified, a random UUID is used.
            Using the same index allows you to store documents across multiple InMemoryDocumentStore instances.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded
            executor will be initialized and used. An executor created here is owned by this instance and is shut
            down by `shutdown()` or when the instance is garbage-collected; an executor passed in is never shut down
            by this class.
        :raises ValueError: If `bm25_algorithm` is not one of the supported algorithms.
        """
        self.bm25_tokenization_regex = bm25_tokenization_regex
        self.tokenizer = re.compile(bm25_tokenization_regex).findall

        if index is None:
            index = str(uuid.uuid4())

        self.index = index
        if self.index not in _STORAGES:
            _STORAGES[self.index] = {}

        self.bm25_algorithm = bm25_algorithm
        self.bm25_algorithm_inst = self._dispatch_bm25()
        self.bm25_parameters = bm25_parameters or {}
        self.embedding_similarity_function = embedding_similarity_function

        # Per-document statistics
        if self.index not in _BM25_STATS_STORAGES:
            _BM25_STATS_STORAGES[self.index] = {}

        if self.index not in _AVERAGE_DOC_LEN_STORAGES:
            _AVERAGE_DOC_LEN_STORAGES[self.index] = 0.0

        if self.index not in _FREQ_VOCAB_FOR_IDF_STORAGES:
            _FREQ_VOCAB_FOR_IDF_STORAGES[self.index] = Counter()

        # keep track of whether we own the executor if we created it we must also clean it up
        self._owns_executor = async_executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-inmemory-docstore-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )

    def __del__(self):
        """
        Cleanup when the instance is being destroyed.

        Shuts down the executor only if this instance created it. The `hasattr` checks cover instances whose
        `__init__` raised (e.g. for an unsupported BM25 algorithm) before these attributes were set.
        """
        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
            self.executor.shutdown(wait=True)

    def shutdown(self):
        """
        Explicitly shutdown the executor if we own it.

        An executor supplied through `async_executor` is left running. Once an owned executor is shut down, the
        `*_async` methods can no longer schedule work on it.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    @property
    def storage(self) -> Dict[str, Document]:
        """
        Utility property that returns the storage used by this instance of InMemoryDocumentStore.
        """
        return _STORAGES.get(self.index, {})

    @property
    def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
        # Per-document BM25 statistics of this index, keyed by document id.
        return _BM25_STATS_STORAGES.get(self.index, {})

    @property
    def _avg_doc_len(self) -> float:
        # Average number of BM25 tokens per document in this index.
        return _AVERAGE_DOC_LEN_STORAGES.get(self.index, 0.0)

    @_avg_doc_len.setter
    def _avg_doc_len(self, value: float):
        _AVERAGE_DOC_LEN_STORAGES[self.index] = value

    @property
    def _freq_vocab_for_idf(self) -> Counter:
        # Number of documents of this index that contain each token.
        return _FREQ_VOCAB_FOR_IDF_STORAGES.get(self.index, Counter())

    def _dispatch_bm25(self):
        """
        Select the correct BM25 algorithm based on user specification.

        :returns:
            The BM25 algorithm method.
        :raises ValueError: If `self.bm25_algorithm` is not supported.
        """
        table = {"BM25Okapi": self._score_bm25okapi, "BM25L": self._score_bm25l, "BM25Plus": self._score_bm25plus}

        if self.bm25_algorithm not in table:
            raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
        return table[self.bm25_algorithm]

    def _tokenize_bm25(self, text: str) -> List[str]:
        """
        Tokenize text using the BM25 tokenization regex.

        Here we explicitly create a tokenization method to encapsulate
        all pre-processing logic used to create BM25 tokens, such as
        lowercasing. This helps track the exact tokenization process
        used for BM25 scoring at any given time.

        :param text:
            The text to tokenize.
        :returns:
            A list of tokens.
        """
        text = text.lower()
        return self.tokenizer(text)

    def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25L scores for the given query and filtered documents.

        :param query:
            The query string.
        :param documents:
            The list of documents to score, should be produced by
            the filter_documents method; may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25L score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        delta = self.bm25_parameters.get("delta", 0.5)
        avg_doc_len = self._avg_doc_len

        if avg_doc_len <= 0:
            # The average token count is 0 only when every stored document has no tokens. Then every document
            # frequency, and therefore every IDF, is 0, so all scores are 0; return them instead of dividing by zero.
            return [(doc, 0.0) for doc in documents]

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation for all tokens; tokens found in no document get an IDF of 0."""
            idf = {}
            n_corpus = len(self._bm25_attr)
            vocab = self._freq_vocab_for_idf
            for tok in tokens:
                n = vocab.get(tok, 0)
                idf[tok] = math.log((n_corpus + 1.0) / (n + 0.5)) * int(n != 0)
            return idf

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
            """Per-token BM25L term-frequency computation."""
            freq_term = freq.get(token, 0.0)
            ctd = freq_term / (1 - b + b * doc_len / avg_doc_len)
            return (1.0 + k) * (ctd + delta) / (k + ctd + delta)

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok, tok_idf in idf.items():
                score += tok_idf * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret

    def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25Okapi scores for the given query and filtered documents.

        :param query:
            The query string.
        :param documents:
            The list of documents to score, should be produced by
            the filter_documents method; may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25Okapi score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        epsilon = self.bm25_parameters.get("epsilon", 0.25)
        avg_doc_len = self._avg_doc_len

        if avg_doc_len <= 0:
            # The average token count is 0 only when every stored document has no tokens, so every term frequency
            # is 0 and all scores are 0; return them instead of dividing by zero.
            return [(doc, 0.0) for doc in documents]

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """
            Per-token IDF computation for all tokens.

            Negative IDF values are replaced by `epsilon` times the average IDF over the vocabulary.
            """
            sum_idf = 0.0
            neg_idf_tokens = []
            n_corpus = len(self._bm25_attr)

            # Although this is a global statistic, we compute it here
            # to make the computation more self-contained. And the
            # complexity is O(vocab_size), which is acceptable.
            idf = {}
            for tok, n in self._freq_vocab_for_idf.items():
                # delete_documents() uses Counter.subtract(), which keeps tokens whose count dropped to 0. Those
                # tokens no longer occur in any document and must not skew the average IDF.
                if n <= 0:
                    continue
                idf[tok] = math.log((n_corpus - n + 0.5) / (n + 0.5))
                sum_idf += idf[tok]
                if idf[tok] < 0:
                    neg_idf_tokens.append(tok)

            # An empty vocabulary means no token occurs anywhere: every IDF below falls back to 0.
            if idf:
                eps = epsilon * sum_idf / len(idf)
                for tok in neg_idf_tokens:
                    idf[tok] = eps
            return {tok: idf.get(tok, 0.0) for tok in tokens}

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
            """Per-token BM25Okapi term-frequency computation."""
            freq_term = freq.get(token, 0.0)
            freq_norm = freq_term + k * (1 - b + b * doc_len / avg_doc_len)
            return freq_term * (1.0 + k) / freq_norm

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok, tok_idf in idf.items():
                score += tok_idf * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret

    def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25+ scores for the given query and filtered documents.

        This implementation follows the document on BM25 Wikipedia page,
        which add 1 (smoothing factor) to document frequency when computing IDF.

        :param query:
            The query string.
        :param documents:
            The list of documents to score, should be produced by
            the filter_documents method; may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25+ score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        delta = self.bm25_parameters.get("delta", 1.0)
        avg_doc_len = self._avg_doc_len

        if avg_doc_len <= 0:
            # The average token count is 0 only when every stored document has no tokens. Then every document
            # frequency, and therefore every IDF, is 0, so all scores are 0; return them instead of dividing by zero.
            return [(doc, 0.0) for doc in documents]

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation; tokens found in no document get an IDF of 0."""
            idf = {}
            n_corpus = len(self._bm25_attr)
            vocab = self._freq_vocab_for_idf
            for tok in tokens:
                n = vocab.get(tok, 0)
                idf[tok] = math.log(1 + (n_corpus - n + 0.5) / (n + 0.5)) * int(n != 0)
            return idf

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: float) -> float:
            """Per-token normalized term frequency."""
            freq_term = freq.get(token, 0.0)
            freq_damp = k * (1 - b + b * doc_len / avg_doc_len)
            return freq_term * (1.0 + k) / (freq_term + freq_damp) + delta

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok, tok_idf in idf.items():
                score += tok_idf * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The async executor is not serialized.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            bm25_tokenization_regex=self.bm25_tokenization_regex,
            bm25_algorithm=self.bm25_algorithm,
            bm25_parameters=self.bm25_parameters,
            embedding_similarity_function=self.embedding_similarity_function,
            index=self.index,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)

    def save_to_disk(self, path: str) -> None:
        """
        Write the database and its data to disk as a JSON file.

        The file holds the output of `to_dict()` plus a "documents" list with every stored Document.

        :param path: The path to the JSON file.
        """
        data: Dict[str, Any] = self.to_dict()
        data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    @classmethod
    def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
        """
        Load the database and its data from disk as a JSON file.

        :param path: The path to the JSON file.
        :returns: The loaded InMemoryDocumentStore.
        :raises FileNotFoundError: If `path` does not exist.
        :raises Exception: If the file cannot be read or parsed as JSON; the original error is chained as the cause.
        """
        if not Path(path).exists():
            raise FileNotFoundError(f"File {path} not found.")

        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}") from e

        documents = data.pop("documents")
        cls_object = default_from_dict(cls, data)
        cls_object.write_documents(
            documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
        )
        return cls_object

    def count_documents(self) -> int:
        """
        Returns the number of how many documents are present in the DocumentStore.
        """
        return len(self.storage)

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol
        documentation.

        :param filters: The filters to apply to the document list.
        :returns: A list of Documents that match the given filters.
        :raises ValueError: If `filters` has neither an "operator" nor a "conditions" key.
        """
        if filters:
            if "operator" not in filters and "conditions" not in filters:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
            return [doc for doc in self.storage.values() if document_matches_filter(filters=filters, document=doc)]
        return list(self.storage.values())

    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
        """
        Refer to the DocumentStore.write_documents() protocol documentation.

        If `policy` is set to `DuplicatePolicy.NONE` defaults to `DuplicatePolicy.FAIL`.

        :returns: The number of documents written (skipped duplicates are not counted).
        :raises ValueError: If `documents` is not a list of Documents.
        :raises DuplicateDocumentError: If a document id already exists and the policy is FAIL.
        """
        if (
            not isinstance(documents, Iterable)
            or isinstance(documents, str)
            or any(not isinstance(doc, Document) for doc in documents)
        ):
            raise ValueError("Please provide a list of Documents.")

        if policy == DuplicatePolicy.NONE:
            policy = DuplicatePolicy.FAIL

        written_documents = len(documents)
        for document in documents:
            if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage:
                if policy == DuplicatePolicy.FAIL:
                    raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
                if policy == DuplicatePolicy.SKIP:
                    logger.warning("ID '{document_id}' already exists", document_id=document.id)
                    written_documents -= 1
                    continue

            # Since the statistics are updated in an incremental manner,
            # we need to explicitly remove the existing document to revert
            # the statistics before updating them with the new document.
            if document.id in self.storage:
                self.delete_documents([document.id])

            tokens = []
            if document.content is not None:
                tokens = self._tokenize_bm25(document.content)

            self.storage[document.id] = document

            self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
            self._freq_vocab_for_idf.update(set(tokens))
            # `_bm25_attr` already includes the new document, so the previous average covered `n_docs - 1`
            # documents. This is the inverse of the update performed in delete_documents().
            n_docs = len(self._bm25_attr)
            self._avg_doc_len = (len(tokens) + self._avg_doc_len * (n_docs - 1)) / n_docs
        return written_documents

    def delete_documents(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with matching document_ids from the DocumentStore.

        Ids that are not present are ignored. BM25 statistics are updated incrementally.

        :param document_ids: The object_ids to delete.
        """
        for doc_id in document_ids:
            if doc_id not in self.storage:
                continue
            del self.storage[doc_id]

            # Update statistics accordingly
            doc_stats = self._bm25_attr.pop(doc_id)
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
            # After the pop, `n_docs + 1` documents contributed to the current average.
            n_docs = len(self._bm25_attr)
            if n_docs == 0:
                self._avg_doc_len = 0.0
            else:
                self._avg_doc_len = (self._avg_doc_len * (n_docs + 1) - doc_len) / n_docs

    def bm25_retrieval(
        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
    ) -> List[Document]:
        """
        Retrieves documents that are most relevant to the query using BM25 algorithm.

        :param query: The query string.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        :raises ValueError: If `query` is empty or `filters` has no "operator" key.
        """
        if not query:
            raise ValueError("Query should be a non-empty string")

        content_type_filter = {"field": "content", "operator": "!=", "value": None}
        if filters:
            if "operator" not in filters:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
            filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
        else:
            filters = content_type_filter

        all_documents = self.filter_documents(filters=filters)
        if len(all_documents) == 0:
            logger.info("No documents found for BM25 retrieval. Returning empty list.")
            return []

        results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]

        # BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
        # It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
        # see https://github.com/deepset-ai/haystack/pull/6889 for more context.
        negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score

        # Create documents with the BM25 score to return them
        return_documents = []
        for doc, score in results:
            if scale_score:
                score = expit(score / BM25_SCALING_FACTOR)

            if not negatives_are_valid and score <= 0.0:
                continue

            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            return_document = Document.from_dict(doc_fields)
            return_documents.append(return_document)

        return return_documents

    def embedding_retrieval(  # pylint: disable=too-many-positional-arguments
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = False,
        return_embedding: bool = False,
    ) -> List[Document]:
        """
        Retrieves documents that are most similar to the query embedding using a vector similarity metric.

        :param query_embedding: Embedding of the query.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        :raises ValueError: If `query_embedding` is empty or its first element is not a float.
        """
        if len(query_embedding) == 0 or not isinstance(query_embedding[0], float):
            raise ValueError("query_embedding should be a non-empty list of floats.")

        filters = filters or {}
        all_documents = self.filter_documents(filters=filters)

        documents_with_embeddings = [doc for doc in all_documents if doc.embedding is not None]
        if len(documents_with_embeddings) == 0:
            logger.warning(
                "No Documents found with embeddings. Returning empty list. "
                "To generate embeddings, use a DocumentEmbedder."
            )
            return []
        elif len(documents_with_embeddings) < len(all_documents):
            logger.info(
                "Skipping some Documents that don't have an embedding. To generate embeddings, use a DocumentEmbedder."
            )

        scores = self._compute_query_embedding_similarity_scores(
            embedding=query_embedding, documents=documents_with_embeddings, scale_score=scale_score
        )

        # create Documents with the similarity score for the top k results
        top_documents = []
        for doc, score in sorted(zip(documents_with_embeddings, scores), key=lambda x: x[1], reverse=True)[:top_k]:
            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            if return_embedding is False:
                doc_fields["embedding"] = None
            top_documents.append(Document.from_dict(doc_fields))

        return top_documents

    def _compute_query_embedding_similarity_scores(
        self, embedding: List[float], documents: List[Document], scale_score: bool = False
    ) -> List[float]:
        """
        Computes the similarity scores between the query embedding and the embeddings of the documents.

        :param embedding: Embedding of the query.
        :param documents: A list of Documents.
        :param scale_score: Whether to scale the scores of the Documents. Default is False.
        :returns: A list of scores.
        :raises DocumentStoreError: If the document embeddings have different sizes, or if the query embedding size
            differs from the document embedding size.
        """

        query_embedding = np.array(embedding)
        if query_embedding.ndim == 1:
            query_embedding = np.expand_dims(a=query_embedding, axis=0)

        try:
            document_embeddings = np.array([doc.embedding for doc in documents])
        except ValueError as e:
            if "inhomogeneous shape" in str(e):
                raise DocumentStoreError(
                    "The embedding size of all Documents should be the same. "
                    "Please make sure that the Documents have been embedded with the same model."
                ) from e
            raise
        if document_embeddings.ndim == 1:
            document_embeddings = np.expand_dims(a=document_embeddings, axis=0)

        if self.embedding_similarity_function == "cosine":
            # cosine similarity is a normed dot product. Out-of-place division so integer-valued embeddings
            # (integer arrays) are promoted to float instead of failing an in-place cast.
            query_embedding = query_embedding / np.linalg.norm(x=query_embedding, axis=1, keepdims=True)
            document_embeddings = document_embeddings / np.linalg.norm(x=document_embeddings, axis=1, keepdims=True)

        try:
            scores = np.dot(a=query_embedding, b=document_embeddings.T)[0].tolist()
        except ValueError as e:
            if "shapes" in str(e) and "not aligned" in str(e):
                raise DocumentStoreError(
                    "The embedding size of the query should be the same as the embedding size of the Documents. "
                    "Please make sure that the query has been embedded with the same model as the Documents."
                ) from e
            raise

        if scale_score:
            if self.embedding_similarity_function == "dot_product":
                scores = [expit(float(score / DOT_PRODUCT_SCALING_FACTOR)) for score in scores]
            elif self.embedding_similarity_function == "cosine":
                # map cosine similarity from [-1, 1] to [0, 1]
                scores = [(score + 1) / 2 for score in scores]

        return scores

    async def count_documents_async(self) -> int:
        """
        Returns the number of how many documents are present in the DocumentStore.
        """
        return len(self.storage)

    async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol
        documentation. The synchronous implementation runs on `self.executor`.

        :param filters: The filters to apply to the document list.
        :returns: A list of Documents that match the given filters.
        """
        return await asyncio.get_running_loop().run_in_executor(
            self.executor, lambda: self.filter_documents(filters=filters)
        )

    async def write_documents_async(
        self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
    ) -> int:
        """
        Refer to the DocumentStore.write_documents() protocol documentation.

        If `policy` is set to `DuplicatePolicy.NONE` defaults to `DuplicatePolicy.FAIL`.
        The synchronous implementation runs on `self.executor`.
        """
        return await asyncio.get_running_loop().run_in_executor(
            self.executor, lambda: self.write_documents(documents=documents, policy=policy)
        )

    async def delete_documents_async(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with matching document_ids from the DocumentStore.

        The synchronous implementation runs on `self.executor`.

        :param document_ids: The object_ids to delete.
        """
        await asyncio.get_running_loop().run_in_executor(
            self.executor, lambda: self.delete_documents(document_ids=document_ids)
        )

    async def bm25_retrieval_async(
        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
    ) -> List[Document]:
        """
        Retrieves documents that are most relevant to the query using BM25 algorithm.

        The synchronous implementation runs on `self.executor`.

        :param query: The query string.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        return await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score),
        )

    async def embedding_retrieval_async(  # pylint: disable=too-many-positional-arguments
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = False,
        return_embedding: bool = False,
    ) -> List[Document]:
        """
        Retrieves documents that are most similar to the query embedding using a vector similarity metric.

        The synchronous implementation runs on `self.executor`.

        :param query_embedding: Embedding of the query.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        return await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.embedding_retrieval(
                query_embedding=query_embedding,
                filters=filters,
                top_k=top_k,
                scale_score=scale_score,
                return_embedding=return_embedding,
            ),
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc