• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12899

pending completion
#12899

push

travis-ci

web-flow
<a href="https://github.com/datajoint/datajoint-python/commit/864be0ccca479b08973e3dc4531e096bf97088fa">864be0ccc</a>: Merge <a class="double-link" href="https://github.com/datajoint/datajoint-python/commit/d3b1af13150e5e3a26410b98f2dc2a19ec2b5368">d3b1af131</a> into 3b6e84588

79 of 79 new or added lines in 10 files covered. (100.0%)

3059 of 3414 relevant lines covered (89.6%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.71
/datajoint/expression.py
1
from itertools import count
1✔
2
import logging
1✔
3
import inspect
1✔
4
import copy
1✔
5
import re
1✔
6
from .settings import config
1✔
7
from .errors import DataJointError
1✔
8
from .fetch import Fetch, Fetch1
1✔
9
from .preview import preview, repr_html
1✔
10
from .condition import (
1✔
11
    AndList,
12
    Not,
13
    make_condition,
14
    assert_join_compatibility,
15
    extract_column_names,
16
    PromiscuousOperand,
17
)
18
from .declare import CONSTANT_LITERALS
1✔
19

20
# Name the logger after the top-level package (the text before the first dot of
# the module path), so every submodule logs through one shared logger.
_package_name, _, _ = __name__.partition(".")
logger = logging.getLogger(_package_name)
21

22

23
class QueryExpression:
    """
    QueryExpression implements query operators to derive new entity sets from their inputs.
    A QueryExpression object generates a SELECT statement in SQL.
    QueryExpression operators are restrict, join, proj, aggr, and union.

    A QueryExpression object has a support, a restriction (an AndList), and a heading.
    Property `heading` (type dj.Heading) contains information about the attributes.
    It is loaded from the database and updated by proj.

    Property `support` is the list of table names or other QueryExpressions to be joined.

    The restriction is applied first without having access to the attributes generated by the projection.
    Then projection is applied by selecting or modifying the heading attributes.

    Application of operators does not always lead to the creation of a subquery.
    A subquery is generated when:
        1. A restriction is applied on any computed or renamed attributes
        2. A projection is applied remapping already remapped attributes
        3. Subclasses: Join, Aggregation, and Union have additional specific rules.
    """

    _restriction = None
    _restriction_attributes = None
    # list of booleans: True for left joins, False for inner joins.
    # Operators always build a new list (see join) and never mutate it in place.
    _left = []
    _original_heading = None  # heading before projections

    # subclasses or instantiators must provide values
    _connection = None
    _heading = None
    _support = None

    # If the query will be using distinct
    _distinct = False

    @property
    def connection(self):
        """a dj.Connection object"""
        assert self._connection is not None
        return self._connection

    @property
    def support(self):
        """A list of table names or subqueries to form the FROM clause"""
        assert self._support is not None
        return self._support

    @property
    def heading(self):
        """a dj.Heading object, reflects the effects of the projection operator .proj"""
        return self._heading

    @property
    def original_heading(self):
        """a dj.Heading object reflecting the attributes before projection"""
        return self._original_heading or self.heading

    @property
    def restriction(self):
        """an AndList object of restrictions applied to input to produce the result"""
        if self._restriction is None:
            self._restriction = AndList()
        return self._restriction

    @property
    def restriction_attributes(self):
        """the set of attribute names invoked in the WHERE clause"""
        if self._restriction_attributes is None:
            self._restriction_attributes = set()
        return self._restriction_attributes

    @property
    def primary_key(self):
        """The primary key attribute names of the heading."""
        return self.heading.primary_key

    _subquery_alias_count = count()  # count for alias names used in the FROM clause

    def from_clause(self):
        """
        Build the FROM clause: the supports chained with NATURAL [LEFT] JOINs.

        String supports are used verbatim; QueryExpression supports are rendered as
        parenthesized subqueries aliased with a fresh number from the class counter.
        The i-th join is a LEFT join when self._left[i] is True.
        """
        support = (
            "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count)
            if isinstance(src, QueryExpression)
            else src
            for src in self.support
        )
        clause = next(support)
        for s, left in zip(support, self._left):
            clause += " NATURAL{left} JOIN {clause}".format(
                left=" LEFT" if left else "", clause=s
            )
        return clause

    def where_clause(self):
        """:return: the WHERE clause (with leading space) or "" when unrestricted."""
        return (
            ""
            if not self.restriction
            else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction)
        )

    def make_sql(self, fields=None):
        """
        Make the SQL SELECT statement.

        :param fields: used to explicitly set the select attributes
        """
        return "SELECT {distinct}{fields} FROM {from_}{where}".format(
            distinct="DISTINCT " if self._distinct else "",
            fields=self.heading.as_sql(fields or self.heading.names),
            from_=self.from_clause(),
            where=self.where_clause(),
        )

    # --------- query operators -----------
    def make_subquery(self):
        """create a new SELECT statement where self is the FROM clause"""
        result = QueryExpression()
        result._connection = self.connection
        result._support = [self]
        result._heading = self.heading.make_subquery_heading()
        return result

    def restrict(self, restriction):
        """
        Produces a new expression with the new restriction applied.
        rel.restrict(restriction)  is equivalent to  rel & restriction.
        rel.restrict(Not(restriction))  is equivalent to  rel - restriction
        The primary key of the result is unaffected.
        Successive restrictions are combined as logical AND:   r & a & b  is equivalent to r & AndList((a, b))
        Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists
        (logical disjunction of conditions)
        Inverse restriction is accomplished by either using the subtraction operator or the Not class.

        The expressions in each row are equivalent:

        rel & True                          rel
        rel & False                         the empty entity set
        rel & 'TRUE'                        rel
        rel & 'FALSE'                       the empty entity set
        rel - cond                          rel & Not(cond)
        rel - 'TRUE'                        rel & False
        rel - 'FALSE'                       rel
        rel & AndList((cond1,cond2))        rel & cond1 & cond2
        rel & AndList()                     rel
        rel & [cond1, cond2]                rel & OrList((cond1, cond2))
        rel & []                            rel & False
        rel & None                          rel & False
        rel & any_empty_entity_set          rel & False
        rel - AndList((cond1,cond2))        rel & [Not(cond1), Not(cond2)]
        rel - [cond1, cond2]                rel & Not(cond1) & Not(cond2)
        rel - AndList()                     rel & False
        rel - []                            rel
        rel - None                          rel
        rel - any_empty_entity_set          rel

        When arg is another QueryExpression, the restriction  rel & arg  restricts rel to elements that match at least
        one element in arg (hence arg is treated as an OrList).
        Conversely,  rel - arg  restricts rel to elements that do not match any elements in arg.
        Two elements match when their common attributes have equal values or when they have no common attributes.
        All shared attributes must be in the primary key of either rel or arg or both or an error will be raised.

        QueryExpression.restrict is the only access point that modifies restrictions. All other operators must
        ultimately call restrict()

        :param restriction: a sequence or an array (treated as OR list), another QueryExpression, an SQL condition
        string, or an AndList.
        :return: a new QueryExpression (or self when the restriction has no effect)
        :raises DataJointError: if the condition refers to an attribute not in the heading
        """
        attributes = set()
        new_condition = make_condition(self, restriction, attributes)
        if new_condition is True:
            return self  # restriction has no effect, return the same object
        # check that all attributes in condition are present in the query
        for attr in attributes:
            if attr not in self.heading.names:
                raise DataJointError("Attribute `%s` is not found in query." % (attr,))
        # If the heading has new (computed or renamed) attributes, a subquery is required
        # so that the WHERE clause can see them as columns.
        # However, Aggregation's HAVING statement works fine with aliased attributes.
        need_subquery = isinstance(self, Union) or (
            not isinstance(self, Aggregation) and self.heading.new_attributes
        )
        if need_subquery:
            result = self.make_subquery()
        else:
            result = copy.copy(self)
            # copy.copy is shallow: give the result its own restriction containers so
            # appending to them below does not modify the original expression.
            result._restriction = AndList(self.restriction)
            result._restriction_attributes = set(self.restriction_attributes)
        result.restriction.append(new_condition)
        result.restriction_attributes.update(attributes)
        return result

    def restrict_in_place(self, restriction):
        """Apply `restriction` to this object itself by adopting the restricted result's state."""
        self.__dict__.update(self.restrict(restriction).__dict__)

    def __and__(self, restriction):
        """
        Restriction operator e.g. ``q1 & q2``.
        :return: a restricted copy of the input argument
        See QueryExpression.restrict for more detail.
        """
        return self.restrict(restriction)

    def __xor__(self, restriction):
        """
        Permissive restriction operator ignoring compatibility check  e.g. ``q1 ^ q2``.
        """
        if inspect.isclass(restriction) and issubclass(restriction, QueryExpression):
            restriction = restriction()
        if isinstance(restriction, Not):
            return self.restrict(Not(PromiscuousOperand(restriction.restriction)))
        return self.restrict(PromiscuousOperand(restriction))

    def __sub__(self, restriction):
        """
        Inverted restriction e.g. ``q1 - q2``.
        :return: a restricted copy of the input argument
        See QueryExpression.restrict for more detail.
        """
        return self.restrict(Not(restriction))

    def __neg__(self):
        """
        Convert between restriction and inverted restriction e.g. ``-q1``.
        :return: target restriction
        See QueryExpression.restrict for more detail.
        """
        if isinstance(self, Not):
            return self.restriction
        return Not(self)

    def __mul__(self, other):
        """
        join of query expressions `self` and `other` e.g. ``q1 * q2``.
        """
        return self.join(other)

    def __matmul__(self, other):
        """
        Permissive join of query expressions `self` and `other` ignoring compatibility check
            e.g. ``q1 @ q2``.
        """
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate
        return self.join(other, semantic_check=False)

    def join(self, other, semantic_check=True, left=False):
        """
        create the joined QueryExpression.
        a * b  is short for a.join(b)
        a @ b  is short for a.join(b, semantic_check=False)
        Additionally, left=True will retain the rows of self, effectively performing a left join.

        :raises DataJointError: if `other` is not a QueryExpression (or QueryExpression class)
        """
        if isinstance(other, U):
            return other * self
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate
        if not isinstance(other, QueryExpression):
            raise DataJointError("The argument of join must be a QueryExpression")
        if semantic_check:
            assert_join_compatibility(self, other)
        join_attributes = set(n for n in self.heading.names if n in other.heading.names)
        # Both operands need a subquery if their FROM clauses share attributes that are
        # not join attributes; otherwise NATURAL JOIN would match on them too.
        hidden_overlap = bool(
            (set(self.original_heading.names) & set(other.original_heading.names))
            - join_attributes
        )
        # An operand also needs a subquery if any join attribute is derived in it, or if it
        # is an Aggregation or Union: merging its support list into this join's FROM clause
        # would drop its GROUP BY / UNION semantics.
        need_subquery1 = (
            hidden_overlap
            or isinstance(self, (Aggregation, Union))
            or any(n in self.heading.new_attributes for n in join_attributes)
        )
        need_subquery2 = (
            hidden_overlap
            or isinstance(other, (Aggregation, Union))
            or any(n in other.heading.new_attributes for n in join_attributes)
        )
        if need_subquery1:
            self = self.make_subquery()
        if need_subquery2:
            other = other.make_subquery()
        result = QueryExpression()
        result._connection = self.connection
        result._support = self.support + other.support
        result._left = self._left + [left] + other._left
        result._heading = self.heading.join(other.heading)
        result._restriction = AndList(self.restriction)
        result._restriction.append(other.restriction)
        result._original_heading = self.original_heading.join(other.original_heading)
        assert len(result.support) == len(result._left) + 1
        return result

    def __add__(self, other):
        """union e.g. ``q1 + q2``."""
        return Union.create(self, other)

    def proj(self, *attributes, **named_attributes):
        """
        Projection operator.

        :param attributes:  attributes to be included in the result. (The primary key is already included).
        :param named_attributes: new attributes computed or renamed from existing attributes.
        :return: the projected expression.
        :raises DataJointError: on non-string attributes, excluded primary key attributes,
            unknown attribute names, or name clashes among the new attributes.
        Primary key attributes cannot be excluded but may be renamed.
        If the attribute list contains an Ellipsis ..., then all secondary attributes are included too
        Prefixing an attribute name with a dash '-attr' removes the attribute from the list if present.
        Keyword arguments can be used to rename attributes as in name='attr', duplicate them as in name='(attr)', or
        compute new attributes from SQL expressions as in name='expression'.
        self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self)
        self.proj() -- include only primary key
        self.proj('attr1', 'attr2')  -- include primary key and attributes attr1 and attr2
        self.proj(..., '-attr1', '-attr2')  -- include all attributes except attr1 and attr2
        self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1
        self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup'
        self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax)
        from other attributes available before the projection.
        Each attribute name can only be used once.
        """
        # new attributes in parentheses are included again with the new name without removing original
        duplication_pattern = re.compile(
            rf'^\s*\(\s*(?!{"|".join(CONSTANT_LITERALS)})(?P<name>[a-zA-Z_]\w*)\s*\)\s*$'
        )
        # attributes without parentheses renamed
        rename_pattern = re.compile(
            rf'^\s*(?!{"|".join(CONSTANT_LITERALS)})(?P<name>[a-zA-Z_]\w*)\s*$'
        )
        # Sort each keyword argument into exactly one map. The two patterns are disjoint
        # (only the duplication pattern allows parentheses); anything else is computed.
        replicate_map, rename_map, compute_map = {}, {}, {}
        for new_name, expression in named_attributes.items():
            match = duplication_pattern.match(expression)
            if match:
                replicate_map[new_name] = match.group("name")
                continue
            match = rename_pattern.match(expression)
            if match:
                rename_map[new_name] = match.group("name")
            else:
                compute_map[new_name] = expression

        attributes = set(attributes)
        # include primary key
        attributes.update(k for k in self.primary_key if k not in rename_map.values())
        # include all secondary attributes with Ellipsis
        if Ellipsis in attributes:
            attributes.discard(Ellipsis)
            attributes.update(
                a
                for a in self.heading.secondary_attributes
                if a not in attributes and a not in rename_map.values()
            )
        for a in attributes:
            if not isinstance(a, str):
                raise DataJointError(
                    "%s is not a valid data type for an attribute in .proj" % (a,)
                )
        # remove excluded attributes, specified as `-attr'
        excluded = set(a for a in attributes if a.strip().startswith("-"))
        attributes.difference_update(excluded)
        # strip surrounding whitespace before the dash so ' -attr' yields 'attr'
        excluded = set(a.strip().lstrip("-").strip() for a in excluded)
        attributes.difference_update(excluded)
        for a in excluded:
            if a in self.primary_key:
                raise DataJointError("Cannot exclude primary key attribute %s" % a)
        # check that all attributes exist in heading
        for a in attributes:
            if a not in self.heading.names:
                raise DataJointError("Attribute `%s` not found." % a)
        # check that all names referenced by duplications and renames are present in heading
        for a in (*replicate_map.values(), *rename_map.values()):
            if a not in self.heading.names:
                raise DataJointError("Attribute '%s' not found." % a)

        # check that newly created attributes do not clash with any other selected attributes
        for new_names, other_names in (
            (rename_map, attributes.union(compute_map, replicate_map)),
            (compute_map, attributes.union(rename_map, replicate_map)),
            (replicate_map, attributes.union(rename_map, compute_map)),
        ):
            for name in new_names:
                if name in other_names:
                    raise DataJointError("Attribute `%s` already exists" % name)

        # need a subquery if the projection remaps any remapped attributes
        used = set(q for v in compute_map.values() for q in extract_column_names(v))
        used.update(rename_map.values())
        used.update(replicate_map.values())
        used.intersection_update(self.heading.names)
        need_subquery = isinstance(self, Union) or any(
            self.heading[name].attribute_expression is not None for name in used
        )
        if not need_subquery and self.restriction:
            # need a subquery if the restriction applies to attributes that have been renamed
            need_subquery = any(
                name in self.restriction_attributes
                for name in self.heading.new_attributes
            )

        result = self.make_subquery() if need_subquery else copy.copy(self)
        result._original_heading = result.original_heading
        result._heading = result.heading.select(
            attributes,
            rename_map=dict(**rename_map, **replicate_map),
            compute_map=compute_map,
        )
        return result

    def aggr(self, group, *attributes, keep_all_rows=False, **named_attributes):
        """
        Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
        has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.

        :param group:  The query expression to be aggregated.
        :param attributes: attributes of self to keep; an Ellipsis expands to all of
            self's secondary attributes.
        :param keep_all_rows: True=keep all the rows from self. False=keep only rows that match entries in group.
        :param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
        :return: The derived query expression
        """
        if Ellipsis in attributes:
            # expand ellipsis to include only attributes from the left table
            attributes = set(attributes)
            attributes.discard(Ellipsis)
            attributes.update(self.heading.secondary_attributes)
        return Aggregation.create(self, group=group, keep_all_rows=keep_all_rows).proj(
            *attributes, **named_attributes
        )

    aggregate = aggr  # alias for aggr

    # ---------- Fetch operators --------------------
    @property
    def fetch1(self):
        """A Fetch1 object bound to this query; call it, as in ``q.fetch1()``."""
        return Fetch1(self)

    @property
    def fetch(self):
        """A Fetch object bound to this query; call it, as in ``q.fetch()``."""
        return Fetch(self)

    def head(self, limit=25, **fetch_kwargs):
        """
        shortcut to fetch the first few entries from query expression.
        Equivalent to fetch(order_by="KEY", limit=25)

        :param limit:  number of entries
        :param fetch_kwargs: kwargs for fetch
        :return: query result
        """
        return self.fetch(order_by="KEY", limit=limit, **fetch_kwargs)

    def tail(self, limit=25, **fetch_kwargs):
        """
        shortcut to fetch the last few entries from query expression.
        Equivalent to fetch(order_by="KEY DESC", limit=25)[::-1]

        :param limit:  number of entries
        :param fetch_kwargs: kwargs for fetch
        :return: query result
        """
        return self.fetch(order_by="KEY DESC", limit=limit, **fetch_kwargs)[::-1]

    def __len__(self):
        """:return: number of elements in the result set e.g. ``len(q1)``."""
        # With left joins, count all rows; otherwise count distinct primary keys.
        return self.connection.query(
            "SELECT {select_} FROM {from_}{where}".format(
                select_=(
                    "count(*)"
                    if any(self._left)
                    else "count(DISTINCT {fields})".format(
                        fields=self.heading.as_sql(
                            self.primary_key, include_aliases=False
                        )
                    )
                ),
                from_=self.from_clause(),
                where=self.where_clause(),
            )
        ).fetchone()[0]

    def __bool__(self):
        """
        :return: True if the result is not empty. Equivalent to len(self) > 0 but often
            faster e.g. ``bool(q1)``.
        """
        return bool(
            self.connection.query(
                "SELECT EXISTS(SELECT 1 FROM {from_}{where})".format(
                    from_=self.from_clause(), where=self.where_clause()
                )
            ).fetchone()[0]
        )

    def __contains__(self, item):
        """
        returns True if the restriction in item matches any entries in self
            e.g. ``restriction in q1``.

        :param item: any restriction
        (item in query_expression) is equivalent to bool(query_expression & item) but may be
        executed more efficiently.
        """
        return bool(self & item)  # May be optimized e.g. using an EXISTS query

    def __iter__(self):
        """
        returns an iterator-compatible QueryExpression object e.g. ``iter(q1)``.

        The primary keys are fetched up front; the iteration state is stored on self.

        :param self: iterator-compatible QueryExpression object
        """
        self._iter_only_key = all(v.in_key for v in self.heading.attributes.values())
        self._iter_keys = self.fetch("KEY")
        return self

    def __next__(self):
        """
        returns the next record on an iterator-compatible QueryExpression object
            e.g. ``next(q1)``.

        Keys whose rows can no longer be fetched (fetch1 raises DataJointError) are skipped.

        :param self: A query expression
        :type self: :class:`QueryExpression`
        :rtype: dict
        :raises TypeError: if iter() has not been called on this object
        :raises StopIteration: when all keys have been consumed
        """
        # A loop rather than recursion so that many consecutive missing rows
        # cannot exhaust the call stack.
        while True:
            try:
                key = self._iter_keys.pop(0)
            except AttributeError:
                # self._iter_keys is missing because __iter__ has not been called.
                raise TypeError(
                    "A QueryExpression object is not an iterator. "
                    "Use iter(obj) to create an iterator."
                )
            except IndexError:
                raise StopIteration
            if self._iter_only_key:
                return key
            try:
                return (self & key).fetch1()
            except DataJointError:
                # The data may have been deleted since the moment the keys were fetched
                # -- move on to next entry.
                continue

    def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
        """
        See expression.fetch() for input description.
        :return: query cursor
        :raises DataJointError: if offset is given without limit
        """
        if offset and limit is None:
            raise DataJointError("limit is required when offset is set")
        sql = self.make_sql()
        if order_by is not None:
            sql += " ORDER BY " + ", ".join(order_by)
        if limit is not None:
            sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "")
        logger.debug(sql)
        return self.connection.query(sql, as_dict=as_dict)

    def __repr__(self):
        """
        returns the string representation of a QueryExpression object e.g. ``str(q1)``.

        In debug log level, this is the default object repr; otherwise a preview of the contents.

        :param self: A query expression
        :type self: :class:`QueryExpression`
        :rtype: str
        """
        return (
            super().__repr__()
            if config["loglevel"].lower() == "debug"
            else self.preview()
        )

    def preview(self, limit=None, width=None):
        """:return: a string of preview of the contents of the query."""
        return preview(self, limit, width)

    def _repr_html_(self):
        """:return: HTML to display table in Jupyter notebook."""
        return repr_html(self)
652

653

654
class Aggregation(QueryExpression):
    """
    Aggregation.create(arg, group).proj(comp1='calc1', ..., compn='calcn')  yields an entity set
    with primary key from arg.
    The computed arguments comp1, ..., compn use aggregation calculations on the attributes of
    group or simple projections and calculations on the attributes of arg.
    Aggregation is used by QueryExpression.aggr and U.aggr.
    Aggregation is a private class in DataJoint, not exposed to users.

    Restrictions applied before aggregation go into the WHERE clause (_left_restrict);
    restrictions applied to the aggregation itself go into the HAVING clause.
    """

    _left_restrict = None  # the pre-GROUP BY conditions for the WHERE clause
    _subquery_alias_count = count()

    @classmethod
    def create(cls, arg, group, keep_all_rows=False):
        """
        Build the aggregation of `group` grouped by the primary key of `arg`.

        :param arg: the query expression whose primary key defines the groups
        :param group: query expression (or QueryExpression class) to aggregate
        :param keep_all_rows: if True, left-join so that rows of arg without matches are kept
        :return: an Aggregation; computed attributes are added afterwards via .proj
        """
        if inspect.isclass(group) and issubclass(group, QueryExpression):
            group = group()  # instantiate if a class
        assert isinstance(group, QueryExpression)
        # subquery if left joining a join, or if group has derived attributes
        if keep_all_rows and len(group.support) > 1 or group.heading.new_attributes:
            group = group.make_subquery()
        join = arg.join(group, left=keep_all_rows)  # reuse the join logic
        result = cls()
        result._connection = join.connection
        # use left operand's primary key
        result._heading = join.heading.set_primary_key(arg.primary_key)
        result._support = join.support
        result._left = join._left
        result._left_restrict = join.restriction  # WHERE clause applied before GROUP BY
        result._grouping_attributes = result.primary_key

        return result

    def where_clause(self):
        """:return: the pre-aggregation WHERE clause (with leading space) or ""."""
        return (
            ""
            if not self._left_restrict
            else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict)
        )

    def make_sql(self, fields=None):
        """
        Make the SQL SELECT ... GROUP BY ... HAVING statement.

        :param fields: used to explicitly set the select attributes
        """
        fields = self.heading.as_sql(fields or self.heading.names)
        # restrictions on an aggregation become HAVING conditions, which require a GROUP BY
        assert self._grouping_attributes or not self.restriction
        distinct = set(self.heading.names) == set(self.primary_key)
        return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format(
            distinct="DISTINCT " if distinct else "",
            fields=fields,
            from_=self.from_clause(),
            where=self.where_clause(),
            group_by=""
            if not self.primary_key
            else (
                " GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
                + (
                    ""
                    if not self.restriction
                    # render each condition with str(), as where_clause does
                    else " HAVING (%s)" % ")AND(".join(str(s) for s in self.restriction)
                )
            ),
        )

    def __len__(self):
        """:return: number of rows produced by the aggregation query."""
        return self.connection.query(
            "SELECT count(1) FROM ({subquery}) `${alias:x}`".format(
                subquery=self.make_sql(), alias=next(self._subquery_alias_count)
            )
        ).fetchone()[0]

    def __bool__(self):
        """:return: True if the aggregation query returns at least one row."""
        # fetch the EXISTS value; bool() of the cursor object itself is not the query result
        return bool(
            self.connection.query(
                "SELECT EXISTS({sql})".format(sql=self.make_sql())
            ).fetchone()[0]
        )
726

727

728
class Union(QueryExpression):
    """
    Union is the private DataJoint class that implements the union operator.

    The two operands must share the same primary key and must not share any
    secondary attributes.
    """

    # Source of unique aliases for the derived table produced by make_sql
    # (name-mangled to _Union__count).
    __count = count()

    @classmethod
    def create(cls, arg1, arg2):
        """
        Construct the union of two query expressions.

        :param arg1: the left operand (a QueryExpression).
        :param arg2: the right operand: a QueryExpression instance, or a QueryExpression
            class, which is instantiated.
        :return: a new Union whose heading is the join of both operands' headings.
        :raises DataJointError: if arg2 is not a QueryExpression, the operands come from
            different connections, their primary keys differ, or they share secondary
            attributes.
        """
        if inspect.isclass(arg2) and issubclass(arg2, QueryExpression):
            arg2 = arg2()  # instantiate if a class
        if not isinstance(arg2, QueryExpression):
            raise DataJointError(
                "A QueryExpression can only be unioned with another QueryExpression"
            )
        if arg1.connection != arg2.connection:
            raise DataJointError(
                "Cannot operate on QueryExpressions originating from different connections."
            )
        if set(arg1.primary_key) != set(arg2.primary_key):
            raise DataJointError(
                "The operands of a union must share the same primary key."
            )
        if set(arg1.heading.secondary_attributes) & set(
            arg2.heading.secondary_attributes
        ):
            raise DataJointError(
                "The operands of a union must not share any secondary attributes."
            )
        result = cls()
        result._connection = arg1.connection
        result._heading = arg1.heading.join(arg2.heading)
        result._support = [arg1, arg2]
        return result

    def make_sql(self):
        """
        Build the SQL statement for the union.

        When neither operand has secondary attributes, the operands' primary keys are
        combined with UNION. Otherwise the result is the UNION of (arg1 LEFT JOIN arg2)
        with the rows of arg2 absent from arg1, whose arg1-only secondary attributes
        are filled with NULL.

        :return: the SQL statement as a string.
        """
        arg1, arg2 = self._support
        if (
            not arg1.heading.secondary_attributes
            and not arg2.heading.secondary_attributes
        ):
            # no secondary attributes: use UNION DISTINCT
            fields = arg1.primary_key
            # Nested unions build their own SQL without a field list.
            return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format(
                sql1=arg1.make_sql()
                if isinstance(arg1, Union)
                else arg1.make_sql(fields),
                sql2=arg2.make_sql()
                if isinstance(arg2, Union)
                else arg2.make_sql(fields),
                alias=next(self.__count),
            )
        # with secondary attributes, use union of left join with antijoin
        fields = self.heading.names
        sql1 = arg1.join(arg2, left=True).make_sql(fields)
        sql2 = (
            (arg2 - arg1)
            .proj(..., **{k: "NULL" for k in arg1.heading.secondary_attributes})
            .make_sql(fields)
        )
        return "({sql1})  UNION ({sql2})".format(sql1=sql1, sql2=sql2)

    def from_clause(self):
        """The union does not use a FROM clause"""
        # Explicit raise (rather than `assert False`) so the guard survives `python -O`.
        raise AssertionError("The union does not use a FROM clause")

    def where_clause(self):
        """The union does not use a WHERE clause"""
        # Explicit raise (rather than `assert False`) so the guard survives `python -O`.
        raise AssertionError("The union does not use a WHERE clause")

    def __len__(self):
        """
        Count the rows of the union by wrapping its SQL in a derived table.

        :return: the first column of the first row of the count query.
        """
        return self.connection.query(
            "SELECT count(1) FROM ({subquery}) `${alias:x}`".format(
                subquery=self.make_sql(),
                alias=next(QueryExpression._subquery_alias_count),
            )
        ).fetchone()[0]

    def __bool__(self):
        """
        Report whether the union yields at least one row.

        :return: True if the union is non-empty, as determined with SELECT EXISTS(...).
        """
        cursor = self.connection.query(
            "SELECT EXISTS({sql})".format(sql=self.make_sql())
        )
        # The EXISTS result is the single value in the first row; the object returned
        # by query() is not itself the answer (compare __len__, which reads fetchone()[0]).
        return bool(cursor.fetchone()[0])
810

811

812
class U:
    """
    dj.U objects are the universal sets representing all possible values of their attributes.
    dj.U objects cannot be queried on their own but are useful for forming some queries.
    dj.U('attr1', ..., 'attrn') represents the universal set with the primary key attributes attr1 ... attrn.
    The universal set is the set of all possible combinations of values of the attributes.
    Without any attributes, dj.U() represents the set with one element that has no attributes.

    Restriction:

    dj.U can be used to enumerate unique combinations of values of attributes from other expressions.

    The following expression yields all unique combinations of contrast and brightness found in the `stimulus` set:

    >>> dj.U('contrast', 'brightness') & stimulus

    Aggregation:

    In aggregation, dj.U is used for summary calculation over an entire set:

    The following expression yields one element with one attribute `n` containing the total number of elements in
    query expression `expr`:

    >>> dj.U().aggr(expr, n='count(*)')

    The following expressions both yield one element containing the number `n` of distinct values of attribute `attr` in
    query expression `expr`.

    >>> dj.U().aggr(expr, n='count(distinct attr)')
    >>> dj.U().aggr(dj.U('attr').aggr(expr), n='count(*)')

    The following expression yields one element and one attribute `s` containing the sum of values of attribute `attr`
    over the entire result set of expression `expr`:

    >>> dj.U().aggr(expr, s='sum(attr)')

    The following expression yields the set of all unique combinations of attributes `attr1`, `attr2` and the number of
    their occurrences in the result set of query expression `expr`.

    >>> dj.U('attr1', 'attr2').aggr(expr, n='count(*)')

    Joins:

    If expression `expr` has attributes 'attr1' and 'attr2', then expr * dj.U('attr1','attr2') yields the same result
    as `expr` but `attr1` and `attr2` are promoted to the primary key.  This is useful for producing a join on
    non-primary key attributes.
    For example, if `attr` is in both expr1 and expr2 but not in their primary keys, then expr1 * expr2 will throw
    an error because in most cases, it does not make sense to join on non-primary key attributes and users must first
    rename `attr` in one of the operands.  The expression dj.U('attr') * expr1 * expr2 overrides this constraint.
    """

    def __init__(self, *primary_key):
        """
        :param primary_key: names of the attributes forming this universal set's primary key.
        """
        self._primary_key = primary_key

    @property
    def primary_key(self):
        """The tuple of attribute names passed to the constructor."""
        return self._primary_key

    def __and__(self, other):
        """
        Restrict the universal set by a query expression.

        :param other: a QueryExpression instance, or a QueryExpression class (instantiated).
        :return: a distinct projection of a copy of `other` whose primary key is set to
            this set's attributes.
        :raises DataJointError: if `other` is not a QueryExpression.
        """
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate if a class
        if not isinstance(other, QueryExpression):
            raise DataJointError("Set U can only be restricted with a QueryExpression.")
        result = copy.copy(other)
        result._distinct = True
        result._heading = result.heading.set_primary_key(self.primary_key)
        return result.proj()

    def join(self, other, left=False):
        """
        Joining U with a query expression has the effect of promoting the attributes of U to
        the primary key of the other query expression.

        :param other: the other query expression to join with.
        :param left: ignored. dj.U always acts as if left=False
        :return: a copy of the other query expression with the primary key extended.
        :raises DataJointError: if `other` is not a QueryExpression or lacks one of
            this set's attributes.
        """
        if inspect.isclass(other) and issubclass(other, QueryExpression):
            other = other()  # instantiate if a class
        if not isinstance(other, QueryExpression):
            raise DataJointError("Set U can only be joined with a QueryExpression.")
        missing = next(
            (k for k in self.primary_key if k not in other.heading.names), None
        )
        if missing is not None:
            raise DataJointError("Attribute `%s` not found" % missing)
        result = copy.copy(other)
        # Keep other's primary key first, then append U's attributes not already in it.
        result._heading = result.heading.set_primary_key(
            other.primary_key
            + [k for k in self.primary_key if k not in other.primary_key]
        )
        return result

    def __mul__(self, other):
        """shorthand for join"""
        return self.join(other)

    def aggr(self, group, **named_attributes):
        """
        Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
        has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.

        :param group:  The query expression to be aggregated.
        :param named_attributes: computations of the form new_attribute="sql expression on attributes of group".
            `keep_all_rows` is accepted only as False (the default) and is not treated as a computation.
        :return: The derived query expression
        :raises DataJointError: if keep_all_rows=True is requested.
        """
        # keep_all_rows is an option, not a computed attribute: remove it so that an
        # explicit keep_all_rows=False is not forwarded to proj() as a computation.
        if named_attributes.pop("keep_all_rows", False):
            raise DataJointError(
                "Cannot set keep_all_rows=True when aggregating on a universal set."
            )
        return Aggregation.create(self, group=group, keep_all_rows=False).proj(
            **named_attributes
        )

    aggregate = aggr  # alias for aggr
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc