• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datastax / graph-rag / 14249588432

03 Apr 2025 05:56PM UTC coverage: 93.649% (+0.06%) from 93.585%
14249588432

Pull #174

github

web-flow
Merge 651cbfc30 into 9561d9f93
Pull Request #174: simplified astra adapter implementation

387 of 440 branches covered (87.95%)

Branch coverage included in aggregate %.

103 of 103 new or added lines in 1 file covered. (100.0%)

1 existing line in 1 file now uncovered.

1589 of 1670 relevant lines covered (95.15%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.82
/packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/astra.py
1
"""Provides an adapter for AstraDB vector store integration."""

from __future__ import annotations

from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
from typing import Any, Literal, cast, overload

import backoff
from graph_retriever import Content
from graph_retriever.edges import Edge, IdEdge, MetadataEdge
from graph_retriever.utils import merge
from graph_retriever.utils.batched import batched
from graph_retriever.utils.top_k import top_k
from immutabledict import immutabledict
from typing_extensions import override

# Optional integration dependencies: fail fast at import time with an
# actionable install hint. `ModuleNotFoundError` is a subclass of
# `ImportError`, so catching `ImportError` alone covers both; `from None`
# suppresses the noisy original traceback in the re-raised message.
try:
    from langchain_astradb import AstraDBVectorStore
    from langchain_astradb.vectorstores import AstraDBQueryResult
except ImportError:
    raise ImportError("please `pip install langchain-astradb`") from None

try:
    import astrapy
except ImportError:
    # Note: message previously read "please `pip install astrapy" with an
    # unterminated backtick; corrected here.
    raise ImportError("please `pip install astrapy`") from None
import httpx
from graph_retriever.adapters import Adapter

# Exceptions treated as transient and retried by the backoff-wrapped query
# methods: HTTP transport failures and Data API errors from Astra.
_EXCEPTIONS_TO_RETRY = (
    httpx.TransportError,
    astrapy.exceptions.DataAPIException,
)
# Maximum attempts (initial try + retries) for `backoff.on_exception`.
_MAX_RETRIES = 3
35

36

37
def _extract_queries(edges: set[Edge]) -> tuple[dict[str, Iterable[Any]], set[str]]:
    """
    Partition a set of edges into metadata constraints and document ids.

    Parameters
    ----------
    edges :
        Edges to partition. Only `MetadataEdge` and `IdEdge` are supported.

    Returns
    -------
    :
        A pair `(metadata, ids)`: `metadata` maps each incoming metadata
        field to the set of values to match, and `ids` is the set of
        document ids referenced by `IdEdge`s.

    Raises
    ------
    ValueError
        If an edge of any other type is encountered.
    """
    by_field: dict[str, set[Any]] = {}
    id_set: set[str] = set()

    for candidate in edges:
        if isinstance(candidate, MetadataEdge):
            values = by_field.setdefault(candidate.incoming_field, set())
            values.add(candidate.value)
        elif isinstance(candidate, IdEdge):
            id_set.add(candidate.id)
        else:
            raise ValueError(f"Unsupported edge {candidate}")

    return cast(dict[str, Iterable[Any]], by_field), id_set
50

51

52
def _metadata_queries(
    user_filters: dict[str, Any] | None,
    metadata: dict[str, Iterable[Any]] | None = None,
) -> Iterator[dict[str, Any]]:
    """
    Generate queries for matching all user_filters and any `metadata`.

    The results of the queries can be merged to produce the results.

    Results will match at least one metadata value in one of the metadata fields.

    Results will also match all of the `user_filters`.

    Parameters
    ----------
    user_filters :
        User filters that all results must match.
    metadata :
        An item matches the queries if it matches all user filters, and
        there exists a `key` such that `metadata[key]` has a non-empty
        intersection with the actual values of `item.metadata[key]`.
        Defaults to no metadata constraints.

    Yields
    ------
    :
        Queries corresponding to `user_filters AND metadata`.
    """
    # Fix: the original used a mutable default argument (`metadata={}`);
    # use a `None` sentinel instead. Behavior for all callers is unchanged.
    if metadata is None:
        metadata = {}

    if user_filters:

        def with_user_filters(filter: dict[str, Any]) -> dict[str, Any]:
            # Conjoin the per-field query with the caller's filters.
            return {"$and": [filter, user_filters]}
    else:

        def with_user_filters(filter: dict[str, Any]) -> dict[str, Any]:
            return filter

    def process_value(v: Any) -> Any:
        # immutabledict values must be converted to plain dicts before being
        # embedded in a query document.
        if isinstance(v, immutabledict):
            return dict(v)
        else:
            return v

    for k, v in metadata.items():
        # Values are batched 100 at a time — presumably the Data API's limit
        # on `$in`/`$or` operand counts; TODO confirm against Astra docs.
        for v_batch in batched(v, 100):
            batch = [process_value(v) for v in v_batch]
            if isinstance(batch[0], dict):
                # Dict values can't go through `$in`; use `$all` per value,
                # OR-ing the alternatives together when there are several.
                if len(batch) == 1:
                    yield with_user_filters({k: {"$all": [batch[0]]}})
                else:
                    yield with_user_filters(
                        {"$or": [{k: {"$all": [v]}} for v in batch]}
                    )
            else:
                # Scalar values: direct equality for one, `$in` for many.
                if len(batch) == 1:
                    yield (with_user_filters({k: batch[0]}))
                else:
                    yield (with_user_filters({k: {"$in": batch}}))
109

110

111
async def empty_async_iterable() -> AsyncIterable[AstraDBQueryResult]:
    """Create an async iterable that yields no items."""
    # Returning before the `yield` produces an immediately-exhausted async
    # generator; the unreachable `yield` only marks this function as a
    # generator so callers receive an iterable rather than a coroutine.
    return
    yield
115

116

117
class AstraAdapter(Adapter):
    """
    Adapter for the [AstraDB](https://www.datastax.com/products/datastax-astra) vector store.

    This class integrates the LangChain AstraDB vector store with the graph
    retriever system, providing functionality for similarity search and document
    retrieval.

    Parameters
    ----------
    vector_store :
        The AstraDB vector store instance.
    """  # noqa: E501

    def __init__(self, vector_store: AstraDBVectorStore) -> None:
        # NOTE(review): `copy(component_name=...)` presumably re-tags the
        # store's API calls as coming from this component — confirm against
        # langchain-astradb documentation.
        self.vector_store = vector_store.copy(
            component_name="langchain_graph_retriever"
        )

    def _build_content(self, result: AstraDBQueryResult) -> Content:
        """Convert a raw Astra query result into a `Content`."""
        # All queries issued by this adapter pass `include_embeddings=True`,
        # so the embedding is always expected to be present.
        assert result.embedding is not None
        return Content(
            id=result.id,
            content=result.document.page_content,
            metadata=result.document.metadata,
            embedding=result.embedding,
        )

    def _build_content_iter(
        self, results: Iterable[AstraDBQueryResult]
    ) -> Iterable[Content]:
        """Lazily convert an iterable of query results into `Content` items."""
        for result in results:
            yield self._build_content(result)

    async def _abuild_content_iter(
        self, results: AsyncIterable[AstraDBQueryResult]
    ) -> AsyncIterable[Content]:
        """Async counterpart of `_build_content_iter`."""
        async for result in results:
            yield self._build_content(result)

    @overload
    def _run_query(
        self,
        *,
        n: int,
        include_sort_vector: Literal[False] = False,
        ids: list[str] | None = None,
        filter: dict[str, Any] | None = None,
        sort: dict[str, Any] | None = None,
    ) -> Iterable[Content]: ...

    @overload
    def _run_query(
        self,
        *,
        n: int,
        include_sort_vector: Literal[True],
        ids: list[str] | None = None,
        filter: dict[str, Any] | None = None,
        sort: dict[str, Any] | None = None,
    ) -> tuple[list[float], Iterable[Content]]: ...

    @backoff.on_exception(backoff.expo, _EXCEPTIONS_TO_RETRY, max_tries=_MAX_RETRIES)
    def _run_query(
        self,
        *,
        n: int,
        ids: list[str] | None = None,
        filter: dict[str, Any] | None = None,  # noqa: A002
        sort: dict[str, Any] | None = None,
        include_sort_vector: bool = False,
    ) -> tuple[list[float], Iterable[Content]] | Iterable[Content]:
        """
        Run a single store query, retrying transient failures with backoff.

        Parameters
        ----------
        n :
            Maximum number of results to request.
        ids :
            Optional document ids to restrict the query to.
        filter :
            Optional metadata filter for the query.
        sort :
            Optional sort clause (typically a vector sort).
        include_sort_vector :
            When True, also return the embedding used for the sort.

        Returns
        -------
        :
            `(query_embedding, contents)` when `include_sort_vector` is True,
            otherwise just the contents.
        """
        if include_sort_vector:
            # Work around the fact that `k == 0` is rejected by Astra.
            # AstraDBVectorStore has a similar work around for non-vectorize path, but
            # we want it to apply in both cases.
            query_n = n if n > 0 else 1

            query_embedding, results = self.vector_store.run_query(
                n=query_n,
                ids=ids,
                filter=filter,
                sort=sort,
                include_sort_vector=True,
                include_embeddings=True,
                include_similarity=False,
            )
            assert query_embedding is not None
            if n == 0:
                # The caller wanted zero results; the single row fetched above
                # was only needed to obtain the sort vector, so discard it.
                return query_embedding, self._build_content_iter([])
            return query_embedding, self._build_content_iter(results)
        else:
            results = self.vector_store.run_query(
                n=n,
                ids=ids,
                filter=filter,
                sort=sort,
                include_sort_vector=False,
                include_embeddings=True,
                include_similarity=False,
            )
            return self._build_content_iter(results)

    @overload
    async def _arun_query(
        self,
        *,
        n: int,
        include_sort_vector: Literal[False] = False,
        ids: list[str] | None = None,
        filter: dict[str, Any] | None = None,
        sort: dict[str, Any] | None = None,
    ) -> AsyncIterable[Content]: ...

    @overload
    async def _arun_query(
        self,
        *,
        n: int,
        include_sort_vector: Literal[True],
        ids: list[str] | None = None,
        filter: dict[str, Any] | None = None,
        sort: dict[str, Any] | None = None,
    ) -> tuple[list[float], AsyncIterable[Content]]: ...

    @backoff.on_exception(backoff.expo, _EXCEPTIONS_TO_RETRY, max_tries=_MAX_RETRIES)
    async def _arun_query(
        self,
        *,
        n: int,
        ids: list[str] | None = None,
        filter: dict[str, Any] | None = None,  # noqa: A002
        sort: dict[str, Any] | None = None,
        include_sort_vector: bool = False,
    ) -> tuple[list[float], AsyncIterable[Content]] | AsyncIterable[Content]:
        """Async counterpart of `_run_query`; see that method for parameters."""
        if include_sort_vector:
            # Work around the fact that `k == 0` is rejected by Astra.
            # AstraDBVectorStore has a similar work around for non-vectorize path, but
            # we want it to apply in both cases.
            query_n = n if n > 0 else 1

            query_embedding, results = await self.vector_store.arun_query(
                n=query_n,
                ids=ids,
                filter=filter,
                sort=sort,
                include_sort_vector=True,
                include_embeddings=True,
                include_similarity=False,
            )
            assert query_embedding is not None
            if n == 0:
                # Zero results requested: keep the sort vector, drop the
                # placeholder row fetched by the `query_n = 1` workaround.
                return query_embedding, self._abuild_content_iter(
                    empty_async_iterable()
                )
            return query_embedding, self._abuild_content_iter(results)
        else:
            results = await self.vector_store.arun_query(
                n=n,
                ids=ids,
                filter=filter,
                sort=sort,
                include_sort_vector=False,
                include_embeddings=True,
                include_similarity=False,
            )
            return self._abuild_content_iter(results)

    def _vector_sort_from_embedding(
        self,
        embedding: list[float],
    ) -> dict[str, Any]:
        """Encode an embedding as the store's vector-sort clause."""
        return self.vector_store.document_codec.encode_vector_sort(vector=embedding)

    def _get_sort_and_optional_embedding(
        self, query: str, k: int
    ) -> tuple[None | list[float], dict[str, Any] | None]:
        """
        Compute the sort clause for `query`, embedding client-side if needed.

        Returns
        -------
        :
            `(embedding, sort)`. With server-side embeddings the embedding is
            `None` (the server computes it from the vectorize sort). With
            client-side embeddings and `k == 0`, the sort is `None` to signal
            the caller to short-circuit without querying.
        """
        if self.vector_store.document_codec.server_side_embeddings:
            sort = self.vector_store.document_codec.encode_vectorize_sort(query)
            return None, sort
        else:
            embedding = self.vector_store._get_safe_embedding().embed_query(query)
            if k == 0:
                return embedding, None  # signal that we should short-circuit
            sort = self._vector_sort_from_embedding(embedding)
            return embedding, sort

    @override
    def search_with_embedding(
        self,
        query: str,
        k: int = 4,
        filter: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> tuple[list[float], list[Content]]:
        query_embedding, sort = self._get_sort_and_optional_embedding(query, k)
        # Short-circuit: client-side embedding with k == 0 needs no query.
        if sort is None and query_embedding is not None:
            return query_embedding, []

        # `include_sort_vector=True` recovers the embedding on the
        # server-side-embeddings path, where it isn't known up front.
        query_embedding, results = self._run_query(
            n=k, filter=filter, sort=sort, include_sort_vector=True
        )
        return query_embedding, list(results)

    @override
    async def asearch_with_embedding(
        self,
        query: str,
        k: int = 4,
        filter: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> tuple[list[float], list[Content]]:
        query_embedding, sort = self._get_sort_and_optional_embedding(query, k)
        # Short-circuit: client-side embedding with k == 0 needs no query.
        if sort is None and query_embedding is not None:
            return query_embedding, []

        query_embedding, results = await self._arun_query(
            n=k, filter=filter, sort=sort, include_sort_vector=True
        )
        return query_embedding, [r async for r in results]

    @override
    def search(
        self,
        embedding: list[float],
        k: int = 4,
        filter: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> list[Content]:
        # Astra rejects k == 0 queries; answer trivially instead.
        if k == 0:
            return []
        sort = self._vector_sort_from_embedding(embedding)
        results = self._run_query(n=k, filter=filter, sort=sort)
        return list(results)

    @override
    async def asearch(
        self,
        embedding: list[float],
        k: int = 4,
        filter: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> list[Content]:
        # Astra rejects k == 0 queries; answer trivially instead.
        if k == 0:
            return []
        sort = self._vector_sort_from_embedding(embedding)
        results = await self._arun_query(n=k, filter=filter, sort=sort)
        return [r async for r in results]

    @override
    def get(
        self, ids: Sequence[str], filter: dict[str, Any] | None = None, **kwargs: Any
    ) -> list[Content]:
        # `n=len(ids)` caps results at the number of requested ids.
        results = self._run_query(n=len(ids), ids=list(ids), filter=filter)
        return list(results)

    @override
    async def aget(
        self, ids: Sequence[str], filter: dict[str, Any] | None = None, **kwargs: Any
    ) -> list[Content]:
        results = await self._arun_query(n=len(ids), ids=list(ids), filter=filter)
        return [r async for r in results]

    @override
    def adjacent(
        self,
        edges: set[Edge],
        query_embedding: list[float],
        k: int,
        filter: dict[str, Any] | None,
        **kwargs: Any,
    ) -> Iterable[Content]:
        sort = self._vector_sort_from_embedding(query_embedding)
        # Split edges into per-field metadata constraints and explicit ids.
        metadata, ids = _extract_queries(edges)

        metadata_queries = _metadata_queries(user_filters=filter, metadata=metadata)

        # Deduplicate by document id across all queries; a later hit for the
        # same id simply overwrites the earlier one.
        results: dict[str, Content] = {}
        for metadata_query in metadata_queries:
            # TODO: Look at a thread-pool for this.
            for result in self._run_query(n=k, filter=metadata_query, sort=sort):
                results[result.id] = result

        # Id lookups are issued in batches of 100 ids per query.
        for id_batch in batched(ids, 100):
            for result in self._run_query(
                n=k, ids=list(id_batch), filter=filter, sort=sort
            ):
                results[result.id] = result

        # Keep only the k candidates closest to the query embedding.
        return top_k(results.values(), embedding=query_embedding, k=k)

    @override
    async def aadjacent(
        self,
        edges: set[Edge],
        query_embedding: list[float],
        k: int,
        filter: dict[str, Any] | None,
        **kwargs: Any,
    ) -> Iterable[Content]:
        sort = self._vector_sort_from_embedding(query_embedding)
        # Split edges into per-field metadata constraints and explicit ids.
        metadata, ids = _extract_queries(edges)

        metadata_queries = _metadata_queries(user_filters=filter, metadata=metadata)

        # Issue every query up front; each await yields an async iterable of
        # results that is consumed later during the merge.
        iterables = []
        for metadata_query in metadata_queries:
            iterables.append(
                await self._arun_query(n=k, filter=metadata_query, sort=sort)
            )
        for id_batch in batched(ids, 100):
            iterables.append(
                await self._arun_query(
                    n=k, ids=list(id_batch), filter=filter, sort=sort
                )
            )

        iterators: list[AsyncIterator[Content]] = [it.__aiter__() for it in iterables]

        # Interleave all result streams, deduplicating by document id.
        results: dict[str, Content] = {}
        async for result in merge.amerge(*iterators):
            results[result.id] = result

        # Keep only the k candidates closest to the query embedding.
        return top_k(results.values(), embedding=query_embedding, k=k)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc