• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

syrusakbary / aiodataloader / 18587929429

17 Oct 2025 09:03AM UTC coverage: 91.87% (-0.6%) from 92.5%
18587929429

Pull #54

github

web-flow
Merge 941d87833 into ffad8a1b1
Pull Request #54: Import `iscoroutinefunction()` from the `inspect` module on Python >= 3.14

2 of 3 new or added lines in 1 file covered. (66.67%)

1 existing line in 1 file now uncovered.

113 of 123 relevant lines covered (91.87%)

5.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.87
/aiodataloader/__init__.py
1
import sys
6✔
2
from asyncio import (
6✔
3
    AbstractEventLoop,
4
    ensure_future,
5
    Future,
6
    gather,
7
    get_event_loop,
8
    iscoroutine,
9
)
10
from collections import namedtuple
6✔
11
from functools import partial
6✔
12
from typing import (
6✔
13
    Any,
14
    Callable,
15
    Coroutine,
16
    Generic,
17
    Iterable,
18
    Iterator,
19
    List,
20
    MutableMapping,
21
    Optional,
22
    TypeVar,
23
    Union,
24
)
25

26
if sys.version_info >= (3, 14):
6✔
NEW
27
    from inspect import iscoroutinefunction
×
28
else:
29
    from asyncio import iscoroutinefunction
6✔
30

31
if sys.version_info >= (3, 10):
6✔
32
    from typing import TypeGuard
4✔
33
else:
UNCOV
34
    from typing_extensions import TypeGuard
2✔
35

36
# Package version, exposed for runtime introspection.
__version__ = "0.4.2"

# Type variables shared by the loader machinery below.
KeyT = TypeVar("KeyT")  # key passed to load()
ReturnT = TypeVar("ReturnT")  # value produced for each key
CacheKeyT = TypeVar("CacheKeyT")  # cache key derived via get_cache_key
DataLoaderT = TypeVar("DataLoaderT", bound="DataLoader[Any, Any]")  # Self-style return type for method chaining
T = TypeVar("T")  # generic element type (see get_chunks)
43

44

45
def iscoroutinefunctionorpartial(
    fn: Union[Callable[..., ReturnT], "partial[ReturnT]"],
) -> TypeGuard[Callable[..., Coroutine[Any, Any, ReturnT]]]:
    """
    Return True if *fn* is a coroutine function, unwrapping one level of
    ``functools.partial`` first so a wrapped coroutine function also passes.
    """
    if isinstance(fn, partial):
        return iscoroutinefunction(fn.func)
    return iscoroutinefunction(fn)
49

50

51
# Internal queue entry: one pending key together with the Future that will
# receive its loaded value.
Loader = namedtuple("Loader", ["key", "future"])
52

53

54
class DataLoader(Generic[KeyT, ReturnT]):
    """
    Batches and caches loads of individual keys.

    Calls to ``load(key)`` made during a single event-loop iteration are
    collected into one queue and handed to ``batch_load_fn`` as a single
    list of keys; each caller gets back a ``Future`` for its own value.
    When caching is enabled, the Future for each key is memoized so repeated
    loads of the same key share one result.
    """

    # Class-level defaults; __init__ overrides them per instance when the
    # corresponding argument is not None. Subclasses may also override these.
    batch: bool = True
    max_batch_size: Optional[int] = None
    cache: Optional[bool] = True

    def __init__(
        self,
        batch_load_fn: Optional[
            Callable[[List[KeyT]], Coroutine[Any, Any, List[ReturnT]]]
        ] = None,
        batch: Optional[bool] = None,
        max_batch_size: Optional[int] = None,
        cache: Optional[bool] = None,
        get_cache_key: Optional[Callable[[KeyT], Union[CacheKeyT, KeyT]]] = None,
        cache_map: Optional[
            MutableMapping[Union[CacheKeyT, KeyT], "Future[ReturnT]"]
        ] = None,
        loop: Optional[AbstractEventLoop] = None,
    ):
        """
        :param batch_load_fn: coroutine function taking a list of keys and
            returning a list of values; may instead be supplied as a
            ``batch_load_fn`` attribute on a subclass.
        :param batch: when False, every load dispatches immediately instead
            of being batched with its event-loop-iteration siblings.
        :param max_batch_size: split larger queues into chunks of this size.
        :param cache: when falsy, per-key Future memoization is disabled.
        :param get_cache_key: maps a key to the cache key used for
            memoization; defaults to the identity function.
        :param cache_map: pre-existing mutable mapping to use as the cache.
        :param loop: event loop used to create Futures; defaults to
            ``get_event_loop()``.
        """
        self.loop = loop or get_event_loop()

        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        # NOTE(review): if neither the argument nor a subclass attribute
        # provides batch_load_fn, reading self.batch_load_fn here raises
        # AttributeError before the friendlier TypeError below can fire;
        # also, assert statements are stripped under ``python -O``.
        assert iscoroutinefunctionorpartial(
            self.batch_load_fn
        ), "batch_load_fn must be coroutine. Received: {}".format(self.batch_load_fn)

        if not callable(self.batch_load_fn):
            raise TypeError(
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache

        if get_cache_key is not None:
            self.get_cache_key = get_cache_key
        # Fall back to the identity function unless the argument above or a
        # subclass already supplied get_cache_key.
        if not hasattr(self, "get_cache_key"):
            self.get_cache_key = lambda x: x

        # Memoized Futures keyed by cache key, and the queue of loads
        # awaiting the next batch dispatch.
        self._cache = cache_map if cache_map is not None else {}
        self._queue: List[Loader] = []

    def load(self, key: KeyT) -> "Future[ReturnT]":
        """
        Loads a key, returning a `Future` for the value represented by that key.
        """
        if key is None:
            raise TypeError(
                (
                    "The loader.load() function must be called with a value, "
                    "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache-hit, return cached Future.
        if self.cache:
            cached_result = self._cache.get(cache_key)
            if cached_result:
                return cached_result

        # Otherwise, produce a new Future for this value.
        future = self.loop.create_future()
        # If caching, cache this Future.
        if self.cache:
            self._cache[cache_key] = future

        self.do_resolve_reject(key, future)
        return future

    def do_resolve_reject(self, key: KeyT, future: "Future[ReturnT]") -> None:
        # Enqueue this Future to be dispatched.
        self._queue.append(Loader(key=key, future=future))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from "empty" to "full".
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_future_job(self.loop, self)
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)

    def load_many(self, keys: Iterable[KeyT]) -> "Future[List[ReturnT]]":
        """
        Loads multiple keys, returning a list of values

        >>> a, b = await my_loader.load_many([ 'a', 'b' ])

        This is equivalent to the more verbose:

        >>> a, b = await gather(
        >>>    my_loader.load('a'),
        >>>    my_loader.load('b')
        >>> )
        """
        if not isinstance(keys, Iterable):
            raise TypeError(
                (
                    "The loader.load_many() function must be called with Iterable<key> "
                    "but got: {}."
                ).format(keys)
            )

        return gather(*[self.load(key) for key in keys])

    def clear(self: DataLoaderT, key: KeyT) -> DataLoaderT:
        """
        Clears the value at `key` from the cache, if it exists. Returns itself for
        method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._cache.pop(cache_key, None)
        return self

    def clear_all(self: DataLoaderT) -> DataLoaderT:
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`. Returns itself for
        method chaining.
        """
        self._cache.clear()
        return self

    def prime(self: DataLoaderT, key: KeyT, value: ReturnT) -> DataLoaderT:
        """
        Adds the provided key and value to the cache. If the key already exists, no
        change is made. Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._cache:
            # Cache a rejected future if the value is an Error, in order to match
            # the behavior of load(key).
            future = self.loop.create_future()
            if not future.cancelled():
                if isinstance(value, Exception):
                    future.set_exception(value)
                else:
                    future.set_result(value)

            self._cache[cache_key] = future

        return self
212

213

214
def enqueue_post_future_job(
    loop: AbstractEventLoop, loader: DataLoader[Any, Any]
) -> None:
    """
    Schedule a dispatch of *loader*'s queue on a later event-loop iteration,
    so that all loads made in the current iteration can join the batch.
    """

    async def _dispatch_job() -> None:
        dispatch_queue(loader)

    # Defer via call_soon, then run the coroutine as a task.
    loop.call_soon(ensure_future, _dispatch_job())
221

222

223
def get_chunks(iterable_obj: List[T], chunk_size: int = 1) -> Iterator[List[T]]:
    """Yield consecutive slices of *iterable_obj*, each at most *chunk_size*
    items long (a chunk size below 1 is treated as 1)."""
    step = max(1, chunk_size)
    for start in range(0, len(iterable_obj), step):
        yield iterable_obj[start : start + step]
229

230

231
def dispatch_queue(loader: DataLoader[Any, Any]) -> None:
    """
    Drain the loader's pending queue and kick off one or more batch loads.
    """
    # Swap in a fresh queue so loads made from here on batch separately.
    pending = loader._queue
    loader._queue = []

    limit = loader.max_batch_size

    if limit and limit < len(pending):
        # The queue exceeds max_batch_size: segment it and dispatch each
        # chunk as its own batch.
        for batch in get_chunks(pending, limit):
            ensure_future(dispatch_queue_batch(loader, batch))
    else:
        # Otherwise the whole queue goes out as a single batch.
        ensure_future(dispatch_queue_batch(loader, pending))
250

251

252
async def dispatch_queue_batch(
    loader: DataLoader[Any, Any], queue: List[Loader]
) -> None:
    """
    Execute one batch load: call ``loader.batch_load_fn`` with the queued
    keys and resolve or reject each queued Future with its corresponding
    value. Any failure rejects the whole queue via ``failed_dispatch``.
    """
    # Collect all keys to be loaded in this dispatch
    keys = [ql.key for ql in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    batch_future = loader.batch_load_fn(keys)

    # Assert the expected response from batch_load_fn
    if not batch_future or not iscoroutine(batch_future):
        return failed_dispatch(
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the "
                    "function did not return a Coroutine: {}."
                ).format(batch_future)
            ),
        )

    try:
        values = await batch_future
        if not isinstance(values, Iterable):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the "
                    "function did not return a Future of a Iterable: {}."
                ).format(values)
            )

        # Materialize so the length check and zip below see the same items.
        values = list(values)
        if len(values) != len(keys):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the "
                    "function did not return a Future of a Iterable with the same "
                    "length as the Iterable of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Future in the
        # loaded queue.
        for ql, value in zip(queue, values):
            if not ql.future.cancelled():
                if isinstance(value, Exception):
                    ql.future.set_exception(value)
                else:
                    ql.future.set_result(value)

    except Exception as e:
        # Catches both errors from awaiting batch_future and the TypeErrors
        # raised above, so every waiting Future is rejected rather than hung.
        return failed_dispatch(loader, queue, e)
310

311

312
def failed_dispatch(
    loader: DataLoader[Any, Any], queue: List[Loader], error: Exception
) -> None:
    """
    Reject every pending request in *queue* with *error*.

    Each key is evicted from the loader's cache first, so a failed batch is
    not cached, but every (non-cancelled) Future is still rejected so its
    awaiting caller does not hang.
    """
    for pending in queue:
        loader.clear(pending.key)
        future = pending.future
        if not future.cancelled():
            future.set_exception(error)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc