deepset-ai / haystack / build 13674010844

05 Mar 2025 10:36AM UTC coverage: 90.066% (-0.2%) from 90.218%
push · github · web-flow

feat: adding async version of `InMemoryDocumentStore` and associated retrievers (#8963)

* adding classes from experimental

* adding release notes

* adding tests

* merging all into a single class

* adding async retriever methods

* Update haystack/document_stores/in_memory/document_store.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* adding missed tests

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

9620 of 10681 relevant lines covered (90.07%)

0.9 hits per line

Source File: haystack/document_stores/in_memory/document_store.py (96.14% covered)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import math
import re
import uuid
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import expit
from haystack.utils.filters import document_matches_filter

logger = logging.getLogger(__name__)

# Document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
# True. Scaling uses the expit function (inverse of the logit function) after applying a scaling factor
# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method).
# A larger scaling factor decreases the scaled scores. For example, an input of 10 is scaled to 0.99 with
# BM25_SCALING_FACTOR=2 but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically.
# Increase the default if most unscaled scores are larger than expected (>30) and would otherwise all be
# incorrectly mapped to scores ~1.
BM25_SCALING_FACTOR = 8
DOT_PRODUCT_SCALING_FACTOR = 100


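# A worked example of the scaling described above (illustrative):
# expit(x) = 1 / (1 + exp(-x)), so an unscaled BM25 score of 10 becomes
#   expit(10 / 2) = expit(5.0)  ~ 0.993  with BM25_SCALING_FACTOR = 2
#   expit(10 / 8) = expit(1.25) ~ 0.777  with BM25_SCALING_FACTOR = 8 (the default)
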
@dataclass
class BM25DocumentStats:
    """
    A dataclass for managing document statistics for BM25 retrieval.

    :param freq_token: A Counter of token frequencies in the document.
    :param doc_len: Number of tokens in the document.
    """

    freq_token: Dict[str, int]
    doc_len: int


# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
_STORAGES: Dict[str, Dict[str, Document]] = {}
_BM25_STATS_STORAGES: Dict[str, Dict[str, BM25DocumentStats]] = {}
_AVERAGE_DOC_LEN_STORAGES: Dict[str, float] = {}
_FREQ_VOCAB_FOR_IDF_STORAGES: Dict[str, Counter] = {}


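# Illustrative sketch: because the storages above are module-level globals keyed by index name,
# two instances constructed with the same `index` share the same documents.
# >>> store_a = InMemoryDocumentStore(index="shared-index")
# >>> store_b = InMemoryDocumentStore(index="shared-index")
# >>> store_a.write_documents([Document(content="Hello")])
# 1
# >>> store_b.count_documents()
# 1
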
class InMemoryDocumentStore:
    """
    Stores data in-memory. The data is ephemeral by default, but it can be written to disk
    and reloaded with save_to_disk and load_from_disk.
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
        bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
        bm25_parameters: Optional[Dict] = None,
        embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
        index: Optional[str] = None,
        async_executor: Optional[ThreadPoolExecutor] = None,
    ):
        """
        Initializes the DocumentStore.

        :param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval.
        :param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus".
        :param bm25_parameters: Parameters for the BM25 implementation in a dictionary format.
            For example: {'k1': 1.5, 'b': 0.75, 'epsilon': 0.25}
            You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25.
        :param embedding_similarity_function: The similarity function used to compare Document embeddings.
            One of "dot_product" (default) or "cosine". To choose the most appropriate function, look for information
            about your embedding model.
        :param index: A specific index to store the documents. If not specified, a random UUID is used.
            Using the same index allows you to store documents across multiple InMemoryDocumentStore instances.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded
            executor will be initialized and used.
        """
        self.bm25_tokenization_regex = bm25_tokenization_regex
        self.tokenizer = re.compile(bm25_tokenization_regex).findall

        if index is None:
            index = str(uuid.uuid4())

        self.index = index
        if self.index not in _STORAGES:
            _STORAGES[self.index] = {}

        self.bm25_algorithm = bm25_algorithm
        self.bm25_algorithm_inst = self._dispatch_bm25()
        self.bm25_parameters = bm25_parameters or {}
        self.embedding_similarity_function = embedding_similarity_function

        # Per-document statistics
        if self.index not in _BM25_STATS_STORAGES:
            _BM25_STATS_STORAGES[self.index] = {}

        if self.index not in _AVERAGE_DOC_LEN_STORAGES:
            _AVERAGE_DOC_LEN_STORAGES[self.index] = 0.0

        if self.index not in _FREQ_VOCAB_FOR_IDF_STORAGES:
            _FREQ_VOCAB_FOR_IDF_STORAGES[self.index] = Counter()

        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-inmemory-docstore-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )

    @property
    def storage(self) -> Dict[str, Document]:
        """
        Utility property that returns the storage used by this instance of InMemoryDocumentStore.
        """
        return _STORAGES.get(self.index, {})

    @property
    def _bm25_attr(self) -> Dict[str, BM25DocumentStats]:
        return _BM25_STATS_STORAGES.get(self.index, {})

    @property
    def _avg_doc_len(self) -> float:
        return _AVERAGE_DOC_LEN_STORAGES.get(self.index, 0.0)

    @_avg_doc_len.setter
    def _avg_doc_len(self, value: float):
        _AVERAGE_DOC_LEN_STORAGES[self.index] = value

    @property
    def _freq_vocab_for_idf(self) -> Counter:
        return _FREQ_VOCAB_FOR_IDF_STORAGES.get(self.index, Counter())

    def _dispatch_bm25(self):
        """
        Select the correct BM25 algorithm based on user specification.

        :returns:
            The BM25 algorithm method.
        """
        table = {"BM25Okapi": self._score_bm25okapi, "BM25L": self._score_bm25l, "BM25Plus": self._score_bm25plus}

        if self.bm25_algorithm not in table:
            raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
        return table[self.bm25_algorithm]

    def _tokenize_bm25(self, text: str) -> List[str]:
        """
        Tokenize text using the BM25 tokenization regex.

        Here we explicitly create a tokenization method to encapsulate
        all pre-processing logic used to create BM25 tokens, such as
        lowercasing. This helps track the exact tokenization process
        used for BM25 scoring at any given time.

        :param text:
            The text to tokenize.
        :returns:
            A list of tokens.
        """
        text = text.lower()
        return self.tokenizer(text)

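    # Illustrative sketch: with the default regex r"(?u)\b\w\w+\b", tokens are lowercased word
    # characters of length >= 2, so punctuation and single-character tokens are dropped:
    # >>> store = InMemoryDocumentStore()
    # >>> store._tokenize_bm25("A BM25 query, tokenized!")
    # ['bm25', 'query', 'tokenized']
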
    def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25L scores for the given query and filtered documents.

        :param query:
            The query string.
        :param documents:
            The list of documents to score; it should be produced by
            the filter_documents method and may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25L score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        delta = self.bm25_parameters.get("delta", 0.5)

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation for all tokens."""
            idf = {}
            n_corpus = len(self._bm25_attr)
            for tok in tokens:
                n = self._freq_vocab_for_idf.get(tok, 0)
                idf[tok] = math.log((n_corpus + 1.0) / (n + 0.5)) * int(n != 0)
            return idf

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
            """Per-token BM25L computation."""
            freq_term = freq.get(token, 0.0)
            ctd = freq_term / (1 - b + b * doc_len / self._avg_doc_len)
            return (1.0 + k) * (ctd + delta) / (k + ctd + delta)

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
                score += idf[tok] * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret

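    # The scoring above, written out (f(t, d): term frequency, |d|: document length, avgdl:
    # average document length, N: number of documents, n(t): document frequency of t):
    #   c(t, d) = f(t, d) / (1 - b + b * |d| / avgdl)
    #   score(q, d) = sum over t in q of IDF(t) * (k1 + 1) * (c(t, d) + delta) / (k1 + c(t, d) + delta)
    # with IDF(t) = ln((N + 1) / (n(t) + 0.5)) for known tokens and 0 for tokens not in the corpus.
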
    def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25Okapi scores for the given query and filtered documents.

        :param query:
            The query string.
        :param documents:
            The list of documents to score; it should be produced by
            the filter_documents method and may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25Okapi score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        epsilon = self.bm25_parameters.get("epsilon", 0.25)

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation for all tokens."""
            sum_idf = 0.0
            neg_idf_tokens = []

            # Although this is a global statistic, we compute it here
            # to make the computation more self-contained. And the
            # complexity is O(vocab_size), which is acceptable.
            idf = {}
            for tok, n in self._freq_vocab_for_idf.items():
                idf[tok] = math.log((len(self._bm25_attr) - n + 0.5) / (n + 0.5))
                sum_idf += idf[tok]
                if idf[tok] < 0:
                    neg_idf_tokens.append(tok)

            eps = epsilon * sum_idf / len(self._freq_vocab_for_idf)
            for tok in neg_idf_tokens:
                idf[tok] = eps
            return {tok: idf.get(tok, 0.0) for tok in tokens}

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
            """Per-token BM25Okapi computation."""
            freq_term = freq.get(token, 0.0)
            freq_norm = freq_term + k * (1 - b + b * doc_len / self._avg_doc_len)
            return freq_term * (1.0 + k) / freq_norm

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok in idf.keys():
                score += idf[tok] * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret

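    # Note on the epsilon floor above: the classic Okapi IDF, ln((N - n(t) + 0.5) / (n(t) + 0.5)),
    # turns negative for tokens that occur in more than half of the corpus. Rather than dropping
    # those tokens, their IDF is clamped to eps = epsilon * average IDF, mirroring the behavior of
    # the rank_bm25 package referenced in the __init__ docstring.
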
    def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
        """
        Calculate BM25+ scores for the given query and filtered documents.

        This implementation follows the description on the BM25 Wikipedia page,
        which adds 1 (a smoothing factor) to the document frequency when computing IDF.

        :param query:
            The query string.
        :param documents:
            The list of documents to score; it should be produced by
            the filter_documents method and may be an empty list.
        :returns:
            A list of tuples, each containing a Document and its BM25+ score.
        """
        k = self.bm25_parameters.get("k1", 1.5)
        b = self.bm25_parameters.get("b", 0.75)
        delta = self.bm25_parameters.get("delta", 1.0)

        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
            """Per-token IDF computation."""
            idf = {}
            n_corpus = len(self._bm25_attr)
            for tok in tokens:
                n = self._freq_vocab_for_idf.get(tok, 0)
                idf[tok] = math.log(1 + (n_corpus - n + 0.5) / (n + 0.5)) * int(n != 0)
            return idf

        def _compute_tf(token: str, freq: Dict[str, int], doc_len: float) -> float:
            """Per-token normalized term frequency."""
            freq_term = freq.get(token, 0.0)
            freq_damp = k * (1 - b + b * doc_len / self._avg_doc_len)
            return freq_term * (1.0 + k) / (freq_term + freq_damp) + delta

        idf = _compute_idf(self._tokenize_bm25(query))
        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

        ret = []
        for doc in documents:
            doc_stats = bm25_attr[doc.id]
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            score = 0.0
            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
                score += idf[tok] * _compute_tf(tok, freq, doc_len)
            ret.append((doc, score))

        return ret

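    # The scoring above, written out (same notation as for BM25L):
    #   score(q, d) = sum over t in q of IDF(t) * (delta + f(t, d) * (k1 + 1) / (f(t, d) + k1 * (1 - b + b * |d| / avgdl)))
    # with IDF(t) = ln(1 + (N - n(t) + 0.5) / (n(t) + 0.5)) for known tokens and 0 otherwise.
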
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            bm25_tokenization_regex=self.bm25_tokenization_regex,
            bm25_algorithm=self.bm25_algorithm,
            bm25_parameters=self.bm25_parameters,
            embedding_similarity_function=self.embedding_similarity_function,
            index=self.index,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)

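    # Illustrative round trip; note that only the configuration is serialized, not the stored
    # documents (those live in the module-level storages keyed by `index`):
    # >>> store = InMemoryDocumentStore(index="my-index")
    # >>> restored = InMemoryDocumentStore.from_dict(store.to_dict())
    # >>> restored.index
    # 'my-index'
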
    def save_to_disk(self, path: str) -> None:
        """
        Write the database and its data to disk as a JSON file.

        :param path: The path to the JSON file.
        """
        data: Dict[str, Any] = self.to_dict()
        data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
        with open(path, "w") as f:
            json.dump(data, f)

    @classmethod
    def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
        """
        Load the database and its data from disk as a JSON file.

        :param path: The path to the JSON file.
        :returns: The loaded InMemoryDocumentStore.
        """
        if Path(path).exists():
            try:
                with open(path, "r") as f:
                    data = json.load(f)
            except Exception as e:
                raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}")

            documents = data.pop("documents")
            cls_object = default_from_dict(cls, data)
            cls_object.write_documents(
                documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
            )
            return cls_object

        else:
            raise FileNotFoundError(f"File {path} not found.")

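    # Illustrative sketch (the file name is hypothetical):
    # >>> store = InMemoryDocumentStore()
    # >>> store.write_documents([Document(content="persist me")])
    # 1
    # >>> store.save_to_disk("my_docs.json")
    # >>> restored = InMemoryDocumentStore.load_from_disk("my_docs.json")
    # >>> restored.count_documents()
    # 1
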
    def count_documents(self) -> int:
        """
        Returns the number of documents present in the DocumentStore.
        """
        return len(self.storage.keys())

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol
        documentation.

        :param filters: The filters to apply to the document list.
        :returns: A list of Documents that match the given filters.
        """
        if filters:
            if "operator" not in filters and "conditions" not in filters:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
            return [doc for doc in self.storage.values() if document_matches_filter(filters=filters, document=doc)]
        return list(self.storage.values())

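    # Illustrative filter shapes (field names and values are hypothetical):
    # a single comparison ...
    # >>> store.filter_documents({"field": "meta.category", "operator": "==", "value": "news"})
    # ... or a logical combination of conditions:
    # >>> store.filter_documents({
    # ...     "operator": "AND",
    # ...     "conditions": [
    # ...         {"field": "meta.category", "operator": "==", "value": "news"},
    # ...         {"field": "meta.year", "operator": ">=", "value": 2024},
    # ...     ],
    # ... })
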
    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
        """
        Refer to the DocumentStore.write_documents() protocol documentation.

        If `policy` is set to `DuplicatePolicy.NONE`, it defaults to `DuplicatePolicy.FAIL`.
        """
        if (
            not isinstance(documents, Iterable)
            or isinstance(documents, str)
            or any(not isinstance(doc, Document) for doc in documents)
        ):
            raise ValueError("Please provide a list of Documents.")

        if policy == DuplicatePolicy.NONE:
            policy = DuplicatePolicy.FAIL

        written_documents = len(documents)
        for document in documents:
            if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys():
                if policy == DuplicatePolicy.FAIL:
                    raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
                if policy == DuplicatePolicy.SKIP:
                    logger.warning("ID '{document_id}' already exists", document_id=document.id)
                    written_documents -= 1
                    continue

            # Since the statistics are updated in an incremental manner,
            # we need to explicitly remove the existing document to revert
            # the statistics before updating them with the new document.
            if document.id in self.storage.keys():
                self.delete_documents([document.id])

            tokens = []
            if document.content is not None:
                tokens = self._tokenize_bm25(document.content)

            self.storage[document.id] = document

            self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
            self._freq_vocab_for_idf.update(set(tokens))
            self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._bm25_attr)) / (len(self._bm25_attr) + 1)
        return written_documents

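    # Illustrative duplicate handling:
    # >>> store = InMemoryDocumentStore()
    # >>> doc = Document(content="only once")
    # >>> store.write_documents([doc])
    # 1
    # >>> store.write_documents([doc], policy=DuplicatePolicy.SKIP)  # logs a warning and skips
    # 0
    # Writing the same document again with the default policy (NONE, which resolves to FAIL)
    # raises DuplicateDocumentError, while DuplicatePolicy.OVERWRITE replaces it and returns 1.
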
    def delete_documents(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with matching document_ids from the DocumentStore.

        :param document_ids: The IDs of the Documents to delete.
        """
        for doc_id in document_ids:
            if doc_id not in self.storage.keys():
                continue
            del self.storage[doc_id]

            # Update statistics accordingly
            doc_stats = self._bm25_attr.pop(doc_id)
            freq = doc_stats.freq_token
            doc_len = doc_stats.doc_len

            self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
            try:
                self._avg_doc_len = (self._avg_doc_len * (len(self._bm25_attr) + 1) - doc_len) / len(self._bm25_attr)
            except ZeroDivisionError:
                self._avg_doc_len = 0

    def bm25_retrieval(
        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
    ) -> List[Document]:
        """
        Retrieves documents that are most relevant to the query using the BM25 algorithm.

        :param query: The query string.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        if not query:
            raise ValueError("Query should be a non-empty string")

        content_type_filter = {"field": "content", "operator": "!=", "value": None}
        if filters:
            if "operator" not in filters:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
            filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
        else:
            filters = content_type_filter

        all_documents = self.filter_documents(filters=filters)
        if len(all_documents) == 0:
            logger.info("No documents found for BM25 retrieval. Returning empty list.")
            return []

        results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]

        # BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
        # It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
        # See https://github.com/deepset-ai/haystack/pull/6889 for more context.
        negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score

        # Create documents with the BM25 score to return them
        return_documents = []
        for doc, score in results:
            if scale_score:
                score = expit(score / BM25_SCALING_FACTOR)

            if not negatives_are_valid and score <= 0.0:
                continue

            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            return_document = Document.from_dict(doc_fields)
            return_documents.append(return_document)

        return return_documents

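    # Illustrative usage (the texts are hypothetical):
    # >>> store = InMemoryDocumentStore()
    # >>> store.write_documents(
    # ...     [Document(content="BM25 is a ranking function"), Document(content="Vectors enable semantic search")]
    # ... )
    # 2
    # >>> results = store.bm25_retrieval(query="ranking function", top_k=1, scale_score=True)
    # >>> results[0].content, 0.0 < results[0].score <= 1.0
    # ('BM25 is a ranking function', True)
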
    def embedding_retrieval(  # pylint: disable=too-many-positional-arguments
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = False,
        return_embedding: bool = False,
    ) -> List[Document]:
        """
        Retrieves documents that are most similar to the query embedding using a vector similarity metric.

        :param query_embedding: Embedding of the query.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        if len(query_embedding) == 0 or not isinstance(query_embedding[0], float):
            raise ValueError("query_embedding should be a non-empty list of floats.")

        filters = filters or {}
        all_documents = self.filter_documents(filters=filters)

        documents_with_embeddings = [doc for doc in all_documents if doc.embedding is not None]
        if len(documents_with_embeddings) == 0:
            logger.warning(
                "No Documents found with embeddings. Returning empty list. "
                "To generate embeddings, use a DocumentEmbedder."
            )
            return []
        elif len(documents_with_embeddings) < len(all_documents):
            logger.info(
                "Skipping some Documents that don't have an embedding. To generate embeddings, use a DocumentEmbedder."
            )

        scores = self._compute_query_embedding_similarity_scores(
            embedding=query_embedding, documents=documents_with_embeddings, scale_score=scale_score
        )

        # create Documents with the similarity score for the top k results
        top_documents = []
        for doc, score in sorted(zip(documents_with_embeddings, scores), key=lambda x: x[1], reverse=True)[:top_k]:
            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            if return_embedding is False:
                doc_fields["embedding"] = None
            top_documents.append(Document.from_dict(doc_fields))

        return top_documents

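    # Illustrative usage; the 3-dimensional embeddings are toy values, in practice they come
    # from a DocumentEmbedder:
    # >>> store = InMemoryDocumentStore(embedding_similarity_function="cosine")
    # >>> store.write_documents(
    # ...     [
    # ...         Document(content="doc A", embedding=[1.0, 0.0, 0.0]),
    # ...         Document(content="doc B", embedding=[0.0, 1.0, 0.0]),
    # ...     ]
    # ... )
    # 2
    # >>> results = store.embedding_retrieval(query_embedding=[0.9, 0.1, 0.0], top_k=1)
    # >>> results[0].content
    # 'doc A'
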
    def _compute_query_embedding_similarity_scores(
        self, embedding: List[float], documents: List[Document], scale_score: bool = False
    ) -> List[float]:
        """
        Computes the similarity scores between the query embedding and the embeddings of the documents.

        :param embedding: Embedding of the query.
        :param documents: A list of Documents.
        :param scale_score: Whether to scale the scores of the Documents. Default is False.
        :returns: A list of scores.
        """

        query_embedding = np.array(embedding)
        if query_embedding.ndim == 1:
            query_embedding = np.expand_dims(a=query_embedding, axis=0)

        try:
            document_embeddings = np.array([doc.embedding for doc in documents])
        except ValueError as e:
            if "inhomogeneous shape" in str(e):
                raise DocumentStoreError(
                    "The embedding size of all Documents should be the same. "
                    "Please make sure that the Documents have been embedded with the same model."
                ) from e
            raise e
        if document_embeddings.ndim == 1:
            document_embeddings = np.expand_dims(a=document_embeddings, axis=0)

        if self.embedding_similarity_function == "cosine":
            # cosine similarity is a normed dot product
            query_embedding /= np.linalg.norm(x=query_embedding, axis=1, keepdims=True)
            document_embeddings /= np.linalg.norm(x=document_embeddings, axis=1, keepdims=True)

        try:
            scores = np.dot(a=query_embedding, b=document_embeddings.T)[0].tolist()
        except ValueError as e:
            if "shapes" in str(e) and "not aligned" in str(e):
                raise DocumentStoreError(
                    "The embedding size of the query should be the same as the embedding size of the Documents. "
                    "Please make sure that the query has been embedded with the same model as the Documents."
                ) from e
            raise e

        if scale_score:
            if self.embedding_similarity_function == "dot_product":
                scores = [expit(float(score / DOT_PRODUCT_SCALING_FACTOR)) for score in scores]
            elif self.embedding_similarity_function == "cosine":
                scores = [(score + 1) / 2 for score in scores]

        return scores

    async def count_documents_async(self) -> int:
        """
        Returns the number of documents present in the DocumentStore.
        """
        return len(self.storage.keys())

    async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol
        documentation.

        :param filters: The filters to apply to the document list.
        :returns: A list of Documents that match the given filters.
        """
        return await asyncio.get_event_loop().run_in_executor(
            self.executor, lambda: self.filter_documents(filters=filters)
        )

    async def write_documents_async(
        self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
    ) -> int:
        """
        Refer to the DocumentStore.write_documents() protocol documentation.

        If `policy` is set to `DuplicatePolicy.NONE`, it defaults to `DuplicatePolicy.FAIL`.
        """
        return await asyncio.get_event_loop().run_in_executor(
            self.executor, lambda: self.write_documents(documents=documents, policy=policy)
        )

    async def delete_documents_async(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with matching document_ids from the DocumentStore.

        :param document_ids: The IDs of the Documents to delete.
        """
        await asyncio.get_event_loop().run_in_executor(
            self.executor, lambda: self.delete_documents(document_ids=document_ids)
        )

    async def bm25_retrieval_async(
        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
    ) -> List[Document]:
        """
        Retrieves documents that are most relevant to the query using the BM25 algorithm.

        :param query: The query string.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        return await asyncio.get_event_loop().run_in_executor(
            self.executor,
            lambda: self.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score),
        )

    async def embedding_retrieval_async(  # pylint: disable=too-many-positional-arguments
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = False,
        return_embedding: bool = False,
    ) -> List[Document]:
        """
        Retrieves documents that are most similar to the query embedding using a vector similarity metric.

        :param query_embedding: Embedding of the query.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        """
        return await asyncio.get_event_loop().run_in_executor(
            self.executor,
            lambda: self.embedding_retrieval(
                query_embedding=query_embedding,
                filters=filters,
                top_k=top_k,
                scale_score=scale_score,
                return_embedding=return_embedding,
            ),
        )
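
The async methods above do not reimplement anything: each one offloads its synchronous counterpart to the instance's ThreadPoolExecutor via run_in_executor, so the calls can be awaited without blocking the event loop. A minimal end-to-end sketch (illustrative; the document contents and the query are hypothetical):

import asyncio

from haystack.dataclasses import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore


async def main() -> None:
    store = InMemoryDocumentStore()
    await store.write_documents_async(
        [Document(content="BM25 is a ranking function"), Document(content="Haystack pipelines are composable")]
    )
    docs = await store.bm25_retrieval_async(query="ranking", top_k=1)
    print(docs[0].content)  # -> BM25 is a ranking function


asyncio.run(main())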