• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12899

pending completion
#12899

push

travis-ci

web-flow
<a href="https://github.com/datajoint/datajoint-python/commit/864be0ccca479b08973e3dc4531e096bf97088fa">864be0ccc</a> Merge <a href="https://github.com/datajoint/datajoint-python/commit/d3b1af13150e5e3a26410b98f2dc2a19ec2b5368">d3b1af131</a> into <a href="https://github.com/datajoint/datajoint-python/commit/3b6e84588">3b6e84588</a>

79 of 79 new or added lines in 10 files covered. (100.0%)

3059 of 3414 relevant lines covered (89.6%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.7
/datajoint/fetch.py
1
from functools import partial
1✔
2
from pathlib import Path
1✔
3
import logging
1✔
4
import pandas
1✔
5
import itertools
1✔
6
import re
1✔
7
import numpy as np
1✔
8
import uuid
1✔
9
import numbers
1✔
10
from . import blob, hash
1✔
11
from .errors import DataJointError
1✔
12
from .settings import config
1✔
13
from .utils import safe_write
1✔
14

15
# Log under the root package name (everything before the first dot of this
# module's dotted path) so all datajoint modules share one logger.
_package_name = __name__.split(".", 1)[0]
logger = logging.getLogger(_package_name)
16

17

18
class key:
    """
    Sentinel class used to request the primary key as an argument in
    expression.fetch(). The string "KEY" can be used instead of this class.
    """


def is_key(attr):
    """Return True if ``attr`` requests the primary key (the ``key`` class or "KEY")."""
    if attr is key:
        return True
    return attr == "KEY"
29

30

31
def to_dicts(recarray):
    """Yield each row of a record array as a plain dict keyed by field name."""
    field_names = recarray.dtype.names
    for row in recarray:
        yield dict(zip(field_names, row.tolist()))
35

36

37
def _get(connection, attr, data, squeeze, download_path):
    """
    Unpack the raw value of a single attribute fetched from the database.

    This function is called for every attribute of every fetched row.

    :param connection: a dj.Connection object
    :param attr: attribute from the table's heading (supplies the type flags,
        adapter, store, and database used below)
    :param data: literal value fetched from the table; for external storage this
        is the raw 16 bytes of a UUID referencing the external store
    :param squeeze: if True squeeze blobs
    :param download_path: destination directory for fetches that download data,
        e.g. attachments
    :return: unpacked data, or None if the stored value is NULL
    """
    # NULL in the database stays None in Python
    if data is None:
        return

    # handle to the external store for this attribute, if externally stored
    extern = (
        connection.schemas[attr.database].external[attr.store]
        if attr.is_external
        else None
    )

    # apply attribute adapter if present
    adapt = attr.adapter.get if attr.adapter else lambda x: x

    if attr.is_filepath:
        # filepath attributes are always external; data holds the UUID bytes
        return adapt(extern.download_filepath(uuid.UUID(bytes=data))[0])

    if attr.is_attachment:
        # Steps:
        # 1. get the attachment filename
        # 2. check if the file already exists at download_path, verify checksum
        # 3. if exists and checksum passes then return the local filepath
        # 4. Otherwise, download the remote file and return the new filepath
        _uuid = uuid.UUID(bytes=data) if attr.is_external else None
        # internal attachments store b"<name>\0<contents>" in a single buffer
        attachment_name = (
            extern.get_attachment_name(_uuid)
            if attr.is_external
            else data.split(b"\0", 1)[0].decode()
        )
        local_filepath = Path(download_path) / attachment_name
        if local_filepath.is_file():
            # note: `hash` is the datajoint hash module, not the builtin
            attachment_checksum = (
                _uuid if attr.is_external else hash.uuid_from_buffer(data)
            )
            if attachment_checksum == hash.uuid_from_file(
                local_filepath, init_string=attachment_name + "\0"
            ):
                return adapt(
                    str(local_filepath)
                )  # checksum passed, no need to download again
            # a different file already occupies this name locally:
            # generate the next available alias filename (stem_0000.ext, stem_0001.ext, ...)
            for n in itertools.count():
                f = local_filepath.parent / (
                    local_filepath.stem + "_%04x" % n + local_filepath.suffix
                )
                if not f.is_file():
                    # unused alias found -- download to it below
                    local_filepath = f
                    break
                if attachment_checksum == hash.uuid_from_file(
                    f, init_string=attachment_name + "\0"
                ):
                    return adapt(str(f))  # checksum passed, no need to download again
        # Save attachment
        if attr.is_external:
            extern.download_attachment(_uuid, attachment_name, local_filepath)
        else:
            # write from buffer: contents follow the first b"\0" separator
            safe_write(local_filepath, data.split(b"\0", 1)[1])
        return adapt(str(local_filepath))  # download file from remote store

    # remaining types: uuid -> UUID object; blob -> unpacked (fetching external
    # contents first if needed); everything else passes through unchanged
    return adapt(
        uuid.UUID(bytes=data)
        if attr.uuid
        else (
            blob.unpack(
                extern.get(uuid.UUID(bytes=data)) if attr.is_external else data,
                squeeze=squeeze,
            )
            if attr.is_blob
            else data
        )
    )
118

119

120
def _flatten_attribute_list(primary_key, attrs):
1✔
121
    """
122
    :param primary_key: list of attributes in primary key
123
    :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
124
    :return: generator of attributes where "KEY" is replaces with its component attributes
125
    """
126
    for a in attrs:
1✔
127
        if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a):
1✔
128
            yield from primary_key
1✔
129
        elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a):
1✔
130
            yield from (q + " DESC" for q in primary_key)
1✔
131
        else:
132
            yield a
1✔
133

134

135
class Fetch:
    """
    A callable object that retrieves results from a table expression.

    Instances hold a reference to the query expression they fetch from.

    :param expression: the QueryExpression object to fetch from.
    """

    def __init__(self, expression):
        self._expression = expression

    def __call__(
        self,
        *attrs,
        offset=None,
        limit=None,
        order_by=None,
        format=None,
        as_dict=None,
        squeeze=False,
        download_path="."
    ):
        """
        Fetch the expression results from the database into an np.array or a list
        of dictionaries and unpack blob attributes.

        :param attrs: zero or more attributes to fetch. If not provided, the call
            returns all attributes of this table. If provided, returns a value
            (or list of values) per requested attribute.
        :param offset: the number of tuples to skip in the returned result
        :param limit: the maximum number of tuples to return
        :param order_by: a single attribute or a list of attributes to order the
            results by. No ordering should be assumed if order_by=None. To reverse
            the order, add DESC to the attribute name or names: e.g. ("age DESC",
            "frequency"). To order by primary key, use "KEY" or "KEY DESC".
        :param format: effective only when as_dict=None and attrs is empty.
            None: default from config['fetch_format'] or 'array' if not configured;
            "array": return a numpy record array; "frame": return a pandas.DataFrame.
        :param as_dict: return a list of dictionaries instead of a record array.
            Defaults to False for .fetch() and to True for .fetch('KEY').
        :param squeeze: if True, remove extra dimensions from arrays
        :param download_path: destination directory for fetches that download
            data, e.g. attachments
        :return: the table contents as a structured numpy array, pandas DataFrame,
            list of dicts, or per-attribute values, depending on the arguments
        :raises DataJointError: if `format` conflicts with attrs/as_dict or is invalid
        """
        if order_by is not None:
            # if 'order_by' was passed as a string, wrap it into a list
            if isinstance(order_by, str):
                order_by = [order_by]
            # expand "KEY", "KEY ASC", or "KEY DESC" into primary key attributes
            order_by = list(
                _flatten_attribute_list(self._expression.primary_key, order_by)
            )

        attrs_as_dict = as_dict and attrs
        if attrs_as_dict:
            # absorb KEY into attrs and prepare to return attributes as dict (issue #595)
            if any(is_key(k) for k in attrs):
                attrs = list(self._expression.primary_key) + [
                    a for a in attrs if a not in self._expression.primary_key
                ]
        if as_dict is None:
            as_dict = bool(attrs)  # default to True for "KEY" and False otherwise
        # format should not be specified together with attrs or as_dict=True
        if format is not None and (as_dict or attrs):
            raise DataJointError(
                "Cannot specify output format when as_dict=True or "
                "when attributes are selected to be fetched separately."
            )
        if format not in {None, "array", "frame"}:
            raise DataJointError(
                "Fetch output format must be in "
                '{{"array", "frame"}} but "{}" was given'.format(format)
            )

        # fetching everything: resolve the default output format from config
        if not (attrs or as_dict) and format is None:
            format = config["fetch_format"]  # default to array
            if format not in {"array", "frame"}:
                raise DataJointError(
                    'Invalid entry "{}" in datajoint.config["fetch_format"]: '
                    'use "array" or "frame"'.format(format)
                )

        if limit is None and offset is not None:
            logger.warning(
                "Offset set, but no limit. Setting limit to a large number. "
                "Consider setting a limit explicitly."
            )
            limit = 8000000000  # just a very large number to effect no limit

        # per-attribute converter: unpacks blobs, downloads attachments/filepaths
        get = partial(
            _get,
            self._expression.connection,
            squeeze=squeeze,
            download_path=download_path,
        )
        if attrs:  # a list of attributes provided
            attributes = [a for a in attrs if not is_key(a)]
            ret = self._expression.proj(*attributes)
            # recursive call: fetch the projection as a record array first
            ret = ret.fetch(
                offset=offset,
                limit=limit,
                order_by=order_by,
                as_dict=False,
                squeeze=squeeze,
                download_path=download_path,
                format="array",
            )
            if attrs_as_dict:
                # one dict per row, restricted to the requested attributes
                ret = [
                    {k: v for k, v in zip(ret.dtype.names, x) if k in attrs}
                    for x in ret
                ]
            else:
                # one column per requested attribute; "KEY" yields key dicts/rows
                return_values = [
                    list(
                        (to_dicts if as_dict else lambda x: x)(
                            ret[self._expression.primary_key]
                        )
                    )
                    if is_key(attribute)
                    else ret[attribute]
                    for attribute in attrs
                ]
                # a single requested attribute returns its bare column
                ret = return_values[0] if len(attrs) == 1 else return_values
        else:  # fetch all attributes as a numpy.record_array or pandas.DataFrame
            cur = self._expression.cursor(
                as_dict=as_dict, limit=limit, offset=offset, order_by=order_by
            )
            heading = self._expression.heading

            if as_dict:
                ret = [
                    dict(
                        (name, get(heading[name], d[name]))
                        for name in heading.names_shown
                    )
                    for d in cur
                ]
            else:
                ret = list(cur.fetchall())
                # build the record dtype; blob columns that arrive as plain
                # numbers are typed from the first row's value
                record_type = (
                    heading.as_dtype_shown
                    if not ret
                    else np.dtype(
                        [
                            (
                                name,
                                type(value),
                            )  # use the first element to determine blob type
                            if heading[name].is_blob
                            and isinstance(value, numbers.Number)
                            else (name, heading.as_dtype[name])
                            for value, name in zip(ret[0], heading.as_dtype.names)
                        ]
                    )
                )
                # NOTE(review): this try/except re-raises unchanged; it appears
                # to be a leftover debugging aid and could be removed
                try:
                    ret = np.array(ret, dtype=record_type)
                except Exception as e:
                    raise e

                # in debug mode, hidden attributes are kept in the output
                debug_mode = config["loglevel"].lower() == "debug"

                for name in heading:
                    if not debug_mode and heading[name].hide:
                        continue
                    # unpack blobs and externals
                    ret[name] = list(map(partial(get, heading[name]), ret[name]))
                if not debug_mode:
                    # restrict the record array to the visible attributes
                    ret = ret[heading.names_shown]
                if format == "frame":
                    ret = pandas.DataFrame(ret).set_index(heading.primary_key)
        return ret
305

306

307
class Fetch1:
    """
    Fetch object for fetching the result of a query yielding exactly one row.

    :param expression: a query expression to fetch from.
    """

    def __init__(self, expression):
        self._expression = expression

    def __call__(self, *attrs, squeeze=False, download_path="."):
        """
        Fetch the result of a query expression that yields one entry.

        If no attributes are specified, returns the result as a dict.
        If attributes are specified, returns the corresponding results as a tuple
        (or a single value when only one attribute is requested).

        Examples:
        d = rel.fetch1()   # as a dictionary
        a, b = rel.fetch1('a', 'b')   # as a tuple

        :params *attrs: attributes to return when expanding into a tuple.
                 If attrs is empty, the return result is a dict
        :param squeeze:  When true, remove extra dimensions from arrays in attributes
        :param download_path: destination directory for fetches that download
                 data, e.g. attachments
        :return: the one row of the result, as a dict or as a tuple of values
        :raises DataJointError: if the query yields zero rows or more than one row
        """
        heading = self._expression.heading

        if not attrs:  # fetch all attributes, return as ordered dict
            cur = self._expression.cursor(as_dict=True)
            ret = cur.fetchone()
            # a second successful fetchone() means the result had more than one row
            if not ret or cur.fetchone():
                raise DataJointError(
                    "fetch1 requires exactly one tuple in the input set."
                )
            # unpack each value (blobs, attachments, filepaths) via _get
            ret = dict(
                (
                    name,
                    _get(
                        self._expression.connection,
                        heading[name],
                        ret[name],
                        squeeze=squeeze,
                        download_path=download_path,
                    ),
                )
                for name in heading.names_shown
            )
        else:  # fetch some attributes, return as tuple
            attributes = [a for a in attrs if not is_key(a)]
            result = self._expression.proj(*attributes).fetch(
                squeeze=squeeze, download_path=download_path, format="array"
            )
            if len(result) != 1:
                raise DataJointError(
                    "fetch1 should only return one tuple. %d tuples found" % len(result)
                )
            # "KEY" expands to the primary key as a dict; other attrs yield scalars
            return_values = tuple(
                next(to_dicts(result[self._expression.primary_key]))
                if is_key(attribute)
                else result[attribute][0]
                for attribute in attrs
            )
            ret = return_values[0] if len(attrs) == 1 else return_values
        return ret
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc