• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12898

pending completion
#12898

push

travis-ci

web-flow
<a href="https://github.com/datajoint/datajoint-python/commit/715ab40552f63cd79723ed2830c6691b2cb228b9">715ab4055</a> Merge <a class="double-link" href="https://github.com/datajoint/datajoint-python/commit/0a4f193031d8b1e14b09ec62d83c5def3b7421b0">0a4f19303</a> into 3b6e84588

69 of 69 new or added lines in 9 files covered. (100.0%)

3052 of 3381 relevant lines covered (90.27%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.73
/datajoint/expression.py
1
from itertools import count
1✔
2
import logging
1✔
3
import inspect
1✔
4
import copy
1✔
5
import re
1✔
6
from .settings import config
1✔
7
from .errors import DataJointError
1✔
8
from .fetch import Fetch, Fetch1
1✔
9
from .preview import preview, repr_html
1✔
10
from .condition import (
1✔
11
    AndList,
12
    Not,
13
    make_condition,
14
    assert_join_compatibility,
15
    extract_column_names,
16
    PromiscuousOperand,
17
    translate_attribute,
18
)
19
from .declare import CONSTANT_LITERALS
1✔
20

21
# Log under the top-level package name (the first component of this module's dotted name).
logger = logging.getLogger(__name__.split(".")[0])
1✔
22

23

24
class QueryExpression:
    """
    QueryExpression implements query operators to derive a new entity set from its input.
    A QueryExpression object generates a SELECT statement in SQL.
    QueryExpression operators are restrict, join, proj, aggr, and union.

    A QueryExpression object has a support, a restriction (an AndList), and a heading.
    Property `heading` (type dj.Heading) contains information about the attributes.
    It is loaded from the database and updated by proj.

    Property `support` is the list of table names or other QueryExpressions to be joined.

    The restriction is applied first without having access to the attributes generated by the projection.
    Then the projection is applied by modifying the heading attribute.

    Application of operators does not always lead to the creation of a subquery.
    A subquery is generated when:
        1. A restriction is applied on any computed or renamed attributes
        2. A projection is applied remapping remapped attributes
        3. Subclasses: Join, Aggregation, and Union have additional specific rules.
    """

    _restriction = None
    _restriction_attributes = None
    # list of booleans: True for left joins, False for inner joins.
    # This class-level list is a shared default; it is always replaced by
    # assignment (see join) and must never be mutated in place.
    _left = []
    _original_heading = None  # heading before projections

    # subclasses or instantiators must provide values
    _connection = None
    _heading = None
    _support = None

    # If the query will be using distinct
    _distinct = False

    @property
    def connection(self):
        """a dj.Connection object"""
        assert self._connection is not None
        return self._connection

    @property
    def support(self):
        """A list of table names or subqueries to form the FROM clause"""
        assert self._support is not None
        return self._support

    @property
    def heading(self):
        """a dj.Heading object, reflects the effects of the projection operator .proj"""
        return self._heading

    @property
    def original_heading(self):
        """a dj.Heading object reflecting the attributes before projection"""
        return self._original_heading or self.heading

    @property
    def restriction(self):
        """an AndList object of restrictions applied to input to produce the result"""
        if self._restriction is None:
            # created lazily so that each instance gets its own list
            self._restriction = AndList()
        return self._restriction

    @property
    def restriction_attributes(self):
        """the set of attribute names invoked in the WHERE clause"""
        if self._restriction_attributes is None:
            # created lazily so that each instance gets its own set
            self._restriction_attributes = set()
        return self._restriction_attributes

    @property
    def primary_key(self):
        """the primary key of the heading"""
        return self.heading.primary_key

    _subquery_alias_count = count()  # count for alias names used in the FROM clause

    def from_clause(self):
        """
        Build the FROM clause by NATURAL-joining all elements of the support.

        Elements that are QueryExpressions are rendered as aliased subqueries;
        all other elements are used verbatim.

        :return: the SQL text that follows the FROM keyword
        """
        # lazy generator: each subquery is rendered (and its alias number drawn)
        # only when the element is consumed
        sources = (
            "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count)
            if isinstance(src, QueryExpression)
            else src
            for src in self.support
        )
        clause = next(sources)
        # the remaining sources pair up with the join types stored in _left
        for source, left in zip(sources, self._left):
            clause += " NATURAL{left} JOIN {clause}".format(
                left=" LEFT" if left else "", clause=source
            )
        return clause

    def where_clause(self):
        """
        :return: the WHERE clause (with a leading space) combining all restrictions
            with AND, or an empty string when there are no restrictions
        """
        if not self.restriction:
            return ""
        return " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction)

    def make_sql(self, fields=None):
        """
        Make the SQL SELECT statement.

        :param fields: used to explicitly set the select attributes
        :return: the SELECT statement as a string
        """
        return "SELECT {distinct}{fields} FROM {from_}{where}".format(
            distinct="DISTINCT " if self._distinct else "",
            fields=self.heading.as_sql(fields or self.heading.names),
            from_=self.from_clause(),
            where=self.where_clause(),
        )

    # --------- query operators -----------
    def make_subquery(self):
        """create a new SELECT statement where self is the FROM clause"""
        result = QueryExpression()
        result._connection = self.connection
        result._support = [self]
        result._heading = self.heading.make_subquery_heading()
        return result

    def restrict(self, restriction):
        """
        Produces a new expression with the new restriction applied.
        rel.restrict(restriction)  is equivalent to  rel & restriction.
        rel.restrict(Not(restriction))  is equivalent to  rel - restriction
        The primary key of the result is unaffected.
        Successive restrictions are combined as logical AND:   r & a & b  is equivalent to r & AndList((a, b))
        Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists
        (logical disjunction of conditions)
        Inverse restriction is accomplished by either using the subtraction operator or the Not class.

        The expressions in each row equivalent:

        rel & True                          rel
        rel & False                         the empty entity set
        rel & 'TRUE'                        rel
        rel & 'FALSE'                       the empty entity set
        rel - cond                          rel & Not(cond)
        rel - 'TRUE'                        rel & False
        rel - 'FALSE'                       rel
        rel & AndList((cond1,cond2))        rel & cond1 & cond2
        rel & AndList()                     rel
        rel & [cond1, cond2]                rel & OrList((cond1, cond2))
        rel & []                            rel & False
        rel & None                          rel & False
        rel & any_empty_entity_set          rel & False
        rel - AndList((cond1,cond2))        rel & [Not(cond1), Not(cond2)]
        rel - [cond1, cond2]                rel & Not(cond1) & Not(cond2)
        rel - AndList()                     rel & False
        rel - []                            rel
        rel - None                          rel
        rel - any_empty_entity_set          rel

        When arg is another QueryExpression, the restriction  rel & arg  restricts rel to elements that match at least
        one element in arg (hence arg is treated as an OrList).
        Conversely,  rel - arg  restricts rel to elements that do not match any elements in arg.
        Two elements match when their common attributes have equal values or when they have no common attributes.
        All shared attributes must be in the primary key of either rel or arg or both or an error will be raised.

        QueryExpression.restrict is the only access point that modifies restrictions. All other operators must
        ultimately call restrict()

        :param restriction: a sequence or an array (treated as OR list), another QueryExpression, an SQL condition
        string, or an AndList.
        :return: the restricted expression (self itself when the restriction has no effect)
        :raises DataJointError: if the condition refers to an attribute that is not in the heading
        """
        attributes = set()
        new_condition = make_condition(self, restriction, attributes)
        if new_condition is True:
            return self  # restriction has no effect, return the same object
        # check that all attributes in condition are present in the query
        missing = [attr for attr in attributes if attr not in self.heading.names]
        if missing:
            raise DataJointError(
                "Attribute `%s` is not found in query." % (missing[0],)
            )
        # If the new condition uses any new attributes, a subquery is required.
        # However, Aggregation's HAVING statement works fine with aliased attributes.
        need_subquery = isinstance(self, Union) or (
            not isinstance(self, Aggregation) and self.heading.new_attributes
        )
        if need_subquery:
            result = self.make_subquery()
        else:
            result = copy.copy(self)
            # copy to preserve the original
            result._restriction = AndList(self.restriction)
        result.restriction.append(new_condition)
        result.restriction_attributes.update(attributes)
        return result

    def restrict_in_place(self, restriction):
        """
        Apply the restriction to self by overwriting the instance state with
        the state of the restricted expression.

        :param restriction: any restriction accepted by restrict()
        """
        self.__dict__.update(self.restrict(restriction).__dict__)

    def __and__(self, restriction):
        """
        Restriction operator e.g. ``q1 & q2``.
        :return: a restricted copy of the input argument
        See QueryExpression.restrict for more detail.
        """
        return self.restrict(restriction)

    def __xor__(self, restriction):
        """
        Permissive restriction operator ignoring compatibility check  e.g. ``q1 ^ q2``.
        """
        if inspect.isclass(restriction) and issubclass(restriction, QueryExpression):
            restriction = restriction()  # instantiate
        if isinstance(restriction, Not):
            # keep the negation outermost and mark its operand as promiscuous
            return self.restrict(Not(PromiscuousOperand(restriction.restriction)))
        return self.restrict(PromiscuousOperand(restriction))

    def __sub__(self, restriction):
        """
        Inverted restriction e.g. ``q1 - q2``.
        :return: a restricted copy of the input argument
        See QueryExpression.restrict for more detail.
        """
        return self.restrict(Not(restriction))

    def __neg__(self):
        """
        Convert between restriction and inverted restriction e.g. ``-q1``.
        :return: target restriction
        See QueryExpression.restrict for more detail.
        """
        if isinstance(self, Not):
            return self.restriction
        return Not(self)

    def __mul__(self, other):
        """
        join of query expressions `self` and `other` e.g. ``q1 * q2``.
        """
        return self.join(other)

    def __matmul__(self, other):
        """
        Permissive join of query expressions `self` and `other` ignoring compatibility check
            e.g. ``q1 @ q2``.
        """
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate
        return self.join(other, semantic_check=False)

    def join(self, other, semantic_check=True, left=False):
        """
        create the joined QueryExpression.
        a * b  is short for A.join(B)
        a @ b  is short for A.join(B, semantic_check=False)
        Additionally, left=True will retain the rows of self, effectively performing a left join.

        :param other: a QueryExpression (instance or class) or a U object
        :param semantic_check: if True, verify join compatibility of the operands first
        :param left: if True, perform a left join
        :return: the joined expression
        :raises DataJointError: if other is not a QueryExpression
        """
        # trigger subqueries if joining on renamed attributes
        if isinstance(other, U):
            return other * self
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate
        if not isinstance(other, QueryExpression):
            raise DataJointError("The argument of join must be a QueryExpression")
        if semantic_check:
            assert_join_compatibility(self, other)
        join_attributes = set(n for n in self.heading.names if n in other.heading.names)
        # needs subquery if self's FROM clause has common attributes with other's FROM clause
        hidden_overlap = bool(
            (set(self.original_heading.names) & set(other.original_heading.names))
            - join_attributes
        )
        # need subquery if any of the join attributes are derived
        need_subquery1 = (
            hidden_overlap
            or isinstance(self, Aggregation)
            or any(n in self.heading.new_attributes for n in join_attributes)
            or isinstance(self, Union)
        )
        # each operand is judged on its own type: a Union on the right-hand side
        # must itself be wrapped in a subquery before its support is merged
        need_subquery2 = (
            hidden_overlap
            or isinstance(other, Aggregation)
            or any(n in other.heading.new_attributes for n in join_attributes)
            or isinstance(other, Union)
        )
        if need_subquery1:
            self = self.make_subquery()
        if need_subquery2:
            other = other.make_subquery()
        result = QueryExpression()
        result._connection = self.connection
        result._support = self.support + other.support
        result._left = self._left + [left] + other._left
        result._heading = self.heading.join(other.heading)
        result._restriction = AndList(self.restriction)
        result._restriction.append(other.restriction)
        result._original_heading = self.original_heading.join(other.original_heading)
        assert len(result.support) == len(result._left) + 1
        return result

    def __add__(self, other):
        """union e.g. ``q1 + q2``."""
        return Union.create(self, other)

    def proj(self, *attributes, **named_attributes):
        """
        Projection operator.

        :param attributes:  attributes to be included in the result. (The primary key is already included).
        :param named_attributes: new attributes computed or renamed from existing attributes.
        :return: the projected expression.
        :raises DataJointError: if an attribute is not a string, is not found in the heading,
            is an excluded primary key attribute, or if a new attribute name is already in use.
        Primary key attributes cannot be excluded but may be renamed.
        If the attribute list contains an Ellipsis ..., then all secondary attributes are included too
        Prefixing an attribute name with a dash '-attr' removes the attribute from the list if present.
        Keyword arguments can be used to rename attributes as in name='attr', duplicate them as in name='(attr)', or
        compute new attributes as in name='abs(attr)'.
        self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self)
        self.proj() -- include only primary key
        self.proj('attr1', 'attr2')  -- include primary key and attributes attr1 and attr2
        self.proj(..., '-attr1', '-attr2')  -- include all attributes except attr1 and attr2
        self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1
        self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup'
        self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax)
        from other attributes available before the projection.
        Each attribute name can only be used once.
        """
        named_attributes = {
            k: translate_attribute(v)[1] for k, v in named_attributes.items()
        }
        literals = "|".join(CONSTANT_LITERALS)
        # new attributes in parentheses are included again with the new name without removing original
        duplication_pattern = re.compile(
            rf"^\s*\(\s*(?!{literals})(?P<name>[a-zA-Z_]\w*)\s*\)\s*$"
        )
        # attributes without parentheses renamed
        rename_pattern = re.compile(rf"^\s*(?!{literals})(?P<name>[a-zA-Z_]\w*)\s*$")
        # classify each named attribute as a duplication, a rename, or a computation
        replicate_map, rename_map, compute_map = {}, {}, {}
        for new_name, expression in named_attributes.items():
            duplicated = duplication_pattern.match(expression)
            renamed = rename_pattern.match(expression)
            if duplicated:
                replicate_map[new_name] = duplicated.group("name")
            if renamed:
                rename_map[new_name] = renamed.group("name")
            if not duplicated and not renamed:
                compute_map[new_name] = expression
        attributes = set(attributes)
        # include primary key
        attributes.update((k for k in self.primary_key if k not in rename_map.values()))
        # include all secondary attributes with Ellipsis
        if Ellipsis in attributes:
            attributes.discard(Ellipsis)
            attributes.update(
                (
                    a
                    for a in self.heading.secondary_attributes
                    if a not in attributes and a not in rename_map.values()
                )
            )
        invalid = [a for a in attributes if not isinstance(a, str)]
        if invalid:
            # one-element tuple so that a tuple-valued attribute is formatted, not unpacked
            raise DataJointError(
                "%s is not a valid data type for an attribute in .proj" % (invalid[0],)
            )
        # remove excluded attributes, specified as `-attr'
        excluded = set(a for a in attributes if a.strip().startswith("-"))
        attributes.difference_update(excluded)
        excluded = set(a.lstrip("-").strip() for a in excluded)
        attributes.difference_update(excluded)
        excluded_key = [a for a in excluded if a in self.primary_key]
        if excluded_key:
            raise DataJointError(
                "Cannot exclude primary key attribute %s" % (excluded_key[0],)
            )
        # check that all attributes exist in heading
        unknown = [a for a in attributes if a not in self.heading.names]
        if unknown:
            raise DataJointError("Attribute `%s` not found." % (unknown[0],))

        # check that all mentioned names are present in heading
        mentions = attributes.union(replicate_map.values()).union(rename_map.values())
        unknown = [a for a in mentions if a not in self.heading.names]
        if unknown:
            raise DataJointError("Attribute '%s' not found." % (unknown[0],))

        # check that newly created attributes do not clash with any other selected attributes
        for created, taken in (
            (rename_map, attributes.union(compute_map).union(replicate_map)),
            (compute_map, attributes.union(rename_map).union(replicate_map)),
            (replicate_map, attributes.union(rename_map).union(compute_map)),
        ):
            clashes = [a for a in created if a in taken]
            if clashes:
                raise DataJointError("Attribute `%s` already exists" % (clashes[0],))

        # need a subquery if the projection remaps any remapped attributes
        used = set(q for v in compute_map.values() for q in extract_column_names(v))
        used.update(rename_map.values())
        used.update(replicate_map.values())
        used.intersection_update(self.heading.names)
        need_subquery = isinstance(self, Union) or any(
            self.heading[name].attribute_expression is not None for name in used
        )
        if not need_subquery and self.restriction:
            # need a subquery if the restriction applies to attributes that have been renamed
            need_subquery = any(
                name in self.restriction_attributes
                for name in self.heading.new_attributes
            )

        result = self.make_subquery() if need_subquery else copy.copy(self)
        # pin the pre-projection heading before the heading is replaced below
        result._original_heading = result.original_heading
        result._heading = result.heading.select(
            attributes,
            rename_map=dict(**rename_map, **replicate_map),
            compute_map=compute_map,
        )
        return result

    def aggr(self, group, *attributes, keep_all_rows=False, **named_attributes):
        """
        Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
        has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.

        :param group:  The query expression to be aggregated.
        :param attributes: attributes to include in the result; an Ellipsis expands to the
            secondary attributes of self.
        :param keep_all_rows: True=keep all the rows from self. False=keep only rows that match entries in group.
        :param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
        :return: The derived query expression
        """
        if Ellipsis in attributes:
            # expand ellipsis to include only attributes from the left table
            attributes = set(attributes)
            attributes.discard(Ellipsis)
            attributes.update(self.heading.secondary_attributes)
        return Aggregation.create(self, group=group, keep_all_rows=keep_all_rows).proj(
            *attributes, **named_attributes
        )

    aggregate = aggr  # alias for aggr

    # ---------- Fetch operators --------------------
    @property
    def fetch1(self):
        """a Fetch1 object bound to this expression"""
        return Fetch1(self)

    @property
    def fetch(self):
        """a Fetch object bound to this expression"""
        return Fetch(self)

    def head(self, limit=25, **fetch_kwargs):
        """
        shortcut to fetch the first few entries from query expression.
        Equivalent to fetch(order_by="KEY", limit=25)

        :param limit:  number of entries
        :param fetch_kwargs: kwargs for fetch
        :return: query result
        """
        return self.fetch(order_by="KEY", limit=limit, **fetch_kwargs)

    def tail(self, limit=25, **fetch_kwargs):
        """
        shortcut to fetch the last few entries from query expression.
        Equivalent to fetch(order_by="KEY DESC", limit=25)[::-1]

        :param limit:  number of entries
        :param fetch_kwargs: kwargs for fetch
        :return: query result
        """
        return self.fetch(order_by="KEY DESC", limit=limit, **fetch_kwargs)[::-1]

    def __len__(self):
        """:return: number of elements in the result set e.g. ``len(q1)``."""
        if any(self._left):
            select_ = "count(*)"
        else:
            select_ = "count(DISTINCT {fields})".format(
                fields=self.heading.as_sql(self.primary_key, include_aliases=False)
            )
        return self.connection.query(
            "SELECT {select_} FROM {from_}{where}".format(
                select_=select_,
                from_=self.from_clause(),
                where=self.where_clause(),
            )
        ).fetchone()[0]

    def __bool__(self):
        """
        :return: True if the result is not empty. Equivalent to len(self) > 0 but often
            faster e.g. ``bool(q1)``.
        """
        return bool(
            self.connection.query(
                "SELECT EXISTS(SELECT 1 FROM {from_}{where})".format(
                    from_=self.from_clause(), where=self.where_clause()
                )
            ).fetchone()[0]
        )

    def __contains__(self, item):
        """
        returns True if the restriction in item matches any entries in self
            e.g. ``restriction in q1``.

        :param item: any restriction
        (item in query_expression) is equivalent to bool(query_expression & item) but may be
        executed more efficiently.
        """
        return bool(self & item)  # May be optimized e.g. using an EXISTS query

    def __iter__(self):
        """
        returns an iterator-compatible QueryExpression object e.g. ``iter(q1)``.

        The keys are fetched once here; __next__ then consumes them one at a time.

        :return: self
        """
        self._iter_only_key = all(v.in_key for v in self.heading.attributes.values())
        self._iter_keys = self.fetch("KEY")
        return self

    def __next__(self):
        """
        returns the next record on an iterator-compatible QueryExpression object
            e.g. ``next(q1)``.

        :rtype: dict
        :raises TypeError: if iter() has not been called on this object first
        :raises StopIteration: when all keys have been consumed
        """
        try:
            key = self._iter_keys.pop(0)
        except AttributeError:
            # self._iter_keys is missing because __iter__ has not been called.
            raise TypeError(
                "A QueryExpression object is not an iterator. "
                "Use iter(obj) to create an iterator."
            ) from None
        except IndexError:
            raise StopIteration
        if self._iter_only_key:
            return key
        try:
            return (self & key).fetch1()
        except DataJointError:
            # The data may have been deleted since the moment the keys were fetched
            # -- move on to next entry.
            return next(self)

    def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
        """
        See expression.fetch() for input description.
        :return: query cursor
        :raises DataJointError: if offset is set without a limit
        """
        if offset and limit is None:
            raise DataJointError("limit is required when offset is set")
        sql = self.make_sql()
        if order_by is not None:
            sql += " ORDER BY " + ", ".join(order_by)
        if limit is not None:
            sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "")
        logger.debug(sql)
        return self.connection.query(sql, as_dict=as_dict)

    def __repr__(self):
        """
        returns the string representation of a QueryExpression object e.g. ``str(q1)``.

        With the configured log level set to "debug" the default object
        representation is returned; otherwise a preview of the contents.

        :rtype: str
        """
        return (
            super().__repr__()
            if config["loglevel"].lower() == "debug"
            else self.preview()
        )

    def preview(self, limit=None, width=None):
        """:return: a string of preview of the contents of the query."""
        return preview(self, limit, width)

    def _repr_html_(self):
        """:return: HTML to display table in Jupyter notebook."""
        return repr_html(self)
656

657

658
class Aggregation(QueryExpression):
    """
    Aggregation.create(arg, group, comp1='calc1', ..., compn='calcn')  yields an entity set
    with primary key from arg.
    The computed arguments comp1, ..., compn use aggregation calculations on the attributes of
    group or simple projections and calculations on the attributes of arg.
    Aggregation is used by QueryExpression.aggr and U.aggr.
    Aggregation is a private class in DataJoint, not exposed to users.
    """

    _left_restrict = None  # the pre-GROUP BY conditions for the WHERE clause
    _subquery_alias_count = count()

    @classmethod
    def create(cls, arg, group, keep_all_rows=False):
        """
        Construct the aggregation of `group` over the entities of `arg`.

        :param arg: the left operand; its primary key becomes the primary key of the result
        :param group: a QueryExpression instance or class (classes are instantiated)
        :param keep_all_rows: passed as ``left`` to ``arg.join`` to request a left join
        :return: a new Aggregation object
        """
        if inspect.isclass(group) and issubclass(group, QueryExpression):
            group = group()  # instantiate if a class
        assert isinstance(group, QueryExpression)
        # explicit parentheses: `and` binds tighter than `or`
        if (keep_all_rows and len(group.support) > 1) or group.heading.new_attributes:
            group = group.make_subquery()  # subquery if left joining a join
        join = arg.join(group, left=keep_all_rows)  # reuse the join logic
        result = cls()
        result._connection = join.connection
        result._heading = join.heading.set_primary_key(
            arg.primary_key
        )  # use left operand's primary key
        result._support = join.support
        result._left = join._left
        result._left_restrict = join.restriction  # WHERE clause applied before GROUP BY
        result._grouping_attributes = result.primary_key

        return result

    def where_clause(self):
        """
        :return: the WHERE clause built from the pre-GROUP BY conditions, or an
            empty string when there are none.
        """
        return (
            ""
            if not self._left_restrict
            else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict)
        )

    def make_sql(self, fields=None):
        """
        :param fields: attribute names to select; defaults to all names in the heading
        :return: the SELECT statement with GROUP BY (and HAVING when restricted)
        """
        fields = self.heading.as_sql(fields or self.heading.names)
        # a restriction is rendered as HAVING and therefore needs grouping attributes
        assert self._grouping_attributes or not self.restriction
        distinct = set(self.heading.names) == set(self.primary_key)
        return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format(
            distinct="DISTINCT " if distinct else "",
            fields=fields,
            from_=self.from_clause(),
            where=self.where_clause(),
            group_by=""
            if not self.primary_key
            else (
                " GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
                + (
                    ""
                    if not self.restriction
                    # str() mirrors where_clause so non-string conditions do not
                    # break str.join
                    else " HAVING (%s)" % ")AND(".join(str(s) for s in self.restriction)
                )
            ),
        )

    def __len__(self):
        """:return: number of rows in the aggregation, counted over a subquery."""
        return self.connection.query(
            "SELECT count(1) FROM ({subquery}) `${alias:x}`".format(
                subquery=self.make_sql(), alias=next(self._subquery_alias_count)
            )
        ).fetchone()[0]

    def __bool__(self):
        """:return: True if the aggregation yields at least one row."""
        # read the EXISTS value from the cursor, as __len__ does; bool() of the
        # cursor object itself would not reflect the query result
        return bool(
            self.connection.query(
                "SELECT EXISTS({sql})".format(sql=self.make_sql())
            ).fetchone()[0]
        )
730

731

732
class Union(QueryExpression):
    """
    Union is the private DataJoint class that implements the union operator.
    """

    __count = count()

    @classmethod
    def create(cls, arg1, arg2):
        """
        Construct the union of two query expressions.

        :param arg1: the first operand
        :param arg2: the second operand; a QueryExpression instance or class
            (classes are instantiated)
        :return: a new Union object
        :raises DataJointError: if arg2 is not a QueryExpression, if the operands
            come from different connections, if their primary keys differ, or if
            they share any secondary attributes
        """
        if inspect.isclass(arg2) and issubclass(arg2, QueryExpression):
            arg2 = arg2()  # instantiate if a class
        if not isinstance(arg2, QueryExpression):
            raise DataJointError(
                "A QueryExpression can only be unioned with another QueryExpression"
            )
        if arg1.connection != arg2.connection:
            raise DataJointError(
                "Cannot operate on QueryExpressions originating from different connections."
            )
        if set(arg1.primary_key) != set(arg2.primary_key):
            raise DataJointError(
                "The operands of a union must share the same primary key."
            )
        if set(arg1.heading.secondary_attributes) & set(
            arg2.heading.secondary_attributes
        ):
            raise DataJointError(
                "The operands of a union must not share any secondary attributes."
            )
        result = cls()
        result._connection = arg1.connection
        result._heading = arg1.heading.join(arg2.heading)
        result._support = [arg1, arg2]
        return result

    def make_sql(self):
        """:return: the SQL statement implementing the union of the two operands."""
        arg1, arg2 = self._support
        if (
            not arg1.heading.secondary_attributes
            and not arg2.heading.secondary_attributes
        ):
            # no secondary attributes: use UNION DISTINCT
            fields = arg1.primary_key
            return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format(
                # Union.make_sql takes no `fields` argument
                sql1=arg1.make_sql()
                if isinstance(arg1, Union)
                else arg1.make_sql(fields),
                sql2=arg2.make_sql()
                if isinstance(arg2, Union)
                else arg2.make_sql(fields),
                alias=next(self.__count),
            )
        # with secondary attributes, use union of left join with antijoin
        fields = self.heading.names
        sql1 = arg1.join(arg2, left=True).make_sql(fields)
        sql2 = (
            (arg2 - arg1)
            .proj(..., **{k: "NULL" for k in arg1.heading.secondary_attributes})
            .make_sql(fields)
        )
        return "({sql1})  UNION ({sql2})".format(sql1=sql1, sql2=sql2)

    def from_clause(self):
        """The union does not use a FROM clause"""
        # explicit raise: a bare `assert False` is removed under `python -O`
        raise AssertionError("The union does not use a FROM clause")

    def where_clause(self):
        """The union does not use a WHERE clause"""
        # explicit raise: a bare `assert False` is removed under `python -O`
        raise AssertionError("The union does not use a WHERE clause")

    def __len__(self):
        """:return: number of rows in the union, counted over a subquery."""
        return self.connection.query(
            "SELECT count(1) FROM ({subquery}) `${alias:x}`".format(
                subquery=self.make_sql(),
                alias=next(QueryExpression._subquery_alias_count),
            )
        ).fetchone()[0]

    def __bool__(self):
        """:return: True if the union yields at least one row."""
        # read the EXISTS value from the cursor, as __len__ does; bool() of the
        # cursor object itself would not reflect the query result
        return bool(
            self.connection.query(
                "SELECT EXISTS({sql})".format(sql=self.make_sql())
            ).fetchone()[0]
        )
814

815

816
class U:
    """
    dj.U objects are the universal sets representing all possible values of their attributes.
    dj.U objects cannot be queried on their own but are useful for forming some queries.
    dj.U('attr1', ..., 'attrn') represents the universal set with the primary key attributes attr1 ... attrn.
    The universal set is the set of all possible combinations of values of the attributes.
    Without any attributes, dj.U() represents the set with one element that has no attributes.

    Restriction:

    dj.U can be used to enumerate unique combinations of values of attributes from other expressions.

    The following expression yields all unique combinations of contrast and brightness found in the `stimulus` set:

    >>> dj.U('contrast', 'brightness') & stimulus

    Aggregation:

    In aggregation, dj.U is used for summary calculation over an entire set:

    The following expression yields one element with one attribute `n` containing the total number of elements in
    query expression `expr`:

    >>> dj.U().aggr(expr, n='count(*)')

    The following expressions both yield one element containing the number `n` of distinct values of attribute `attr` in
    query expression `expr`.

    >>> dj.U().aggr(expr, n='count(distinct attr)')
    >>> dj.U().aggr(dj.U('attr').aggr(expr), n='count(*)')

    The following expression yields one element and one attribute `s` containing the sum of values of attribute `attr`
    over entire result set of expression `expr`:

    >>> dj.U().aggr(expr, s='sum(attr)')

    The following expression yields the set of all unique combinations of attributes `attr1`, `attr2` and the number of
    their occurrences in the result set of query expression `expr`.

    >>> dj.U('attr1', 'attr2').aggr(expr, n='count(*)')

    Joins:

    If expression `expr` has attributes 'attr1' and 'attr2', then expr * dj.U('attr1','attr2') yields the same result
    as `expr` but `attr1` and `attr2` are promoted to the primary key.  This is useful for producing a join on
    non-primary key attributes.
    For example, if `attr` is in both expr1 and expr2 but not in their primary keys, then expr1 * expr2 will throw
    an error because in most cases, it does not make sense to join on non-primary key attributes and users must first
    rename `attr` in one of the operands.  The expression dj.U('attr') * rel1 * rel2 overrides this constraint.
    """

    def __init__(self, *primary_key):
        """:param primary_key: names of the attributes forming the primary key of the universal set"""
        self._primary_key = primary_key

    @property
    def primary_key(self):
        """:return: the tuple of attribute names passed to the constructor"""
        return self._primary_key

    def __and__(self, other):
        """
        Restrict the universal set by a query expression.

        :param other: a QueryExpression instance or class (classes are instantiated)
        :return: a projection of a copy of `other` marked distinct, with the
            attributes of this U as its primary key
        :raises DataJointError: if `other` is not a QueryExpression
        """
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate if a class
        if not isinstance(other, QueryExpression):
            raise DataJointError("Set U can only be restricted with a QueryExpression.")
        result = copy.copy(other)
        result._distinct = True
        result._heading = result.heading.set_primary_key(self.primary_key)
        result = result.proj()
        return result

    def join(self, other, left=False):
        """
        Joining U with a query expression has the effect of promoting the attributes of U to
        the primary key of the other query expression.

        :param other: the other query expression to join with.
        :param left: ignored. dj.U always acts as if left=False
        :return: a copy of the other query expression with the primary key extended.
        :raises DataJointError: if `other` is not a QueryExpression or if an attribute
            of U is not found in the heading of `other`
        """
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate if a class
        if not isinstance(other, QueryExpression):
            raise DataJointError("Set U can only be joined with a QueryExpression.")
        # report the first attribute of U that is absent from the other heading
        for k in self.primary_key:
            if k not in other.heading.names:
                raise DataJointError("Attribute `%s` not found" % k)
        result = copy.copy(other)
        result._heading = result.heading.set_primary_key(
            other.primary_key
            + [k for k in self.primary_key if k not in other.primary_key]
        )
        return result

    def __mul__(self, other):
        """shorthand for join"""
        return self.join(other)

    def aggr(self, group, **named_attributes):
        """
        Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
        has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.

        :param group:  The query expression to be aggregated.
        :param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
        :return: The derived query expression
        :raises DataJointError: if keep_all_rows=True is passed
        """
        # pop rather than get: keep_all_rows is an option, not a computed
        # attribute, and must not be forwarded to proj()
        if named_attributes.pop("keep_all_rows", False):
            raise DataJointError(
                "Cannot set keep_all_rows=True when aggregating on a universal set."
            )
        return Aggregation.create(self, group=group, keep_all_rows=False).proj(
            **named_attributes
        )

    aggregate = aggr  # alias for aggr
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc