• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12898

pending completion
#12898

push

travis-ci

web-flow
<a href="https://github.com/datajoint/datajoint-python/commit/715ab40552f63cd79723ed2830c6691b2cb228b9">715ab4055</a>: Merge <a class="double-link" href="https://github.com/datajoint/datajoint-python/commit/0a4f193031d8b1e14b09ec62d83c5def3b7421b0">0a4f19303</a> into 3b6e84588

69 of 69 new or added lines in 9 files covered. (100.0%)

3052 of 3381 relevant lines covered (90.27%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.23
/datajoint/condition.py
1
""" methods for generating SQL WHERE clauses from datajoint restriction conditions """
2

3
import inspect
1✔
4
import collections
1✔
5
import re
1✔
6
import uuid
1✔
7
import datetime
1✔
8
import decimal
1✔
9
import numpy
1✔
10
import pandas
1✔
11
import json
1✔
12
from .errors import DataJointError
1✔
13

14
# Matches ``attr``, ``attr.json.path`` and ``attr.json.path:type`` restriction keys.
JSON_PATTERN = re.compile(
    r"^(?P<attr>\w+)(\.(?P<path>[\w.*\[\]]+))?(:(?P<type>[\w(,\s)]+))?$"
)


def translate_attribute(key):
    """
    Translate a restriction key into the SQL expression that selects its value.

    A key is either a plain attribute name (``attr``) or a reference into a JSON
    attribute (``attr.path``), optionally followed by ``:type``.

    :param key: the key string to translate
    :return: a tuple ``(match, expression)``. ``match`` is None when the key does not
        match JSON_PATTERN (``expression`` is then the key unchanged); otherwise it is
        the dict of the named groups ``attr``, ``path`` and ``type``. ``expression`` is
        the bare attribute name when no path is given, or a ``json_value(...)`` call
        when a path is present.
    """
    match = JSON_PATTERN.match(key)
    if match is None:
        return match, key
    match = match.groupdict()
    if match["path"] is None:
        return match, match["attr"]
    # the "returning" clause is only emitted when a type was requested
    returning = f" returning {match['type']}" if match["type"] else ""
    return (
        match,
        f"json_value(`{match['attr']}`, _utf8mb4'$.{match['path']}'{returning})",
    )
33

34

35
class PromiscuousOperand:
    """
    Wrapper marking an operand for which join compatibility is not checked.

    The wrapped object is stored unchanged in the ``operand`` attribute.
    """

    def __init__(self, operand):
        # keep a reference to the wrapped operand as-is
        self.operand = operand
1✔
42

43

44
class AndList(list):
    """
    A list of conditions to be applied to a query expression by logical conjunction:
    the conditions are AND-ed. All other collections (lists, sets, other entity sets,
    etc) are applied by logical disjunction (OR).

    Example:
    expr2 = expr & dj.AndList((cond1, cond2, cond3))
    is equivalent to
    expr2 = expr & cond1 & cond2 & cond3
    """

    def append(self, restriction):
        """
        Add a restriction; an AndList argument is merged in rather than nested.

        :param restriction: the condition (or AndList of conditions) to add
        """
        if not isinstance(restriction, AndList):
            super().append(restriction)
            return
        # flatten: splice the other AndList's items into this one
        self.extend(restriction)
1✔
62

63

64
class Not:
    """
    Marker that inverts a restriction.

    The wrapped restriction is stored unchanged in the ``restriction`` attribute.
    """

    def __init__(self, restriction):
        # hold on to the restriction that is to be negated
        self.restriction = restriction
1✔
69

70

71
def assert_join_compatibility(expr1, expr2):
    """
    Determine if expressions expr1 and expr2 are join-compatible.  To be join-compatible,
    the matching attributes in the two expressions must be in the primary key of one or the
    other expression.
    Raises an exception if not compatible.

    :param expr1: A QueryExpression object
    :param expr2: A QueryExpression object
    :raises DataJointError: if either argument is neither a QueryExpression nor a dj.U,
        or if the two expressions share an attribute that is secondary in both.
    """
    from .expression import QueryExpression, U

    for rel in (expr1, expr2):
        if not isinstance(rel, (U, QueryExpression)):
            # {!r} instead of "%r" % rel: the latter fails when rel is a tuple
            raise DataJointError(
                f"Object {rel!r} is not a QueryExpression and cannot be joined."
            )
    if isinstance(expr1, U) or isinstance(expr2, U):
        return  # dj.U is always compatible
    shared_secondary = set(expr1.heading.secondary_attributes).intersection(
        expr2.heading.secondary_attributes
    )
    if shared_secondary:
        # report one offending attribute (an arbitrary one when there are several)
        raise DataJointError(
            "Cannot join query expressions on dependent attribute `%s`"
            % next(iter(shared_secondary))
        )
1✔
103

104

105
def make_condition(query_expression, condition, columns):
    """
    Translate the input condition into the equivalent SQL condition (a string)

    Supported condition types, checked in this order: Not (unwrapped, toggling
    negation), str, AndList, dj.U, bool, mapping, numpy record, QueryExpression class
    or instance (optionally wrapped in PromiscuousOperand), pandas.DataFrame, and any
    other iterable, whose items are OR-ed.

    :param query_expression: a dj.QueryExpression object to apply condition
    :param condition: any valid restriction object.
    :param columns: a set passed by reference to collect all column names used in the
        condition.
    :return: an SQL condition string or a boolean value.
    :raises DataJointError: for an unsupported restriction type, a badly formed UUID
        value, or a mapping key that is not a valid attribute reference.
    """
    from .expression import QueryExpression, Aggregation, U

    def prep_value(k, v):
        """prepare the SQL equality condition for attribute key k and value v"""
        key_match, k = translate_attribute(k)
        if key_match is None:
            # previously surfaced as an opaque TypeError from subscripting None
            raise DataJointError(f"Invalid attribute reference {k!r} in restriction")
        attribute = query_expression.heading[key_match["attr"]]
        if key_match["path"] is None:
            k = f"`{k}`"
        if attribute.json and key_match["path"] is not None and isinstance(v, dict):
            return f"{k}='{json.dumps(v)}'"
        if v is None:
            return f"{k} IS NULL"
        if attribute.uuid:
            if not isinstance(v, uuid.UUID):
                try:
                    v = uuid.UUID(v)
                except (AttributeError, TypeError, ValueError):
                    # TypeError covers bytes input, which uuid.UUID rejects as hex
                    raise DataJointError(
                        "Badly formed UUID {v} in restriction by `{k}`".format(k=k, v=v)
                    )
            return f"{k}=X'{v.bytes.hex()}'"
        if isinstance(
            v,
            (
                datetime.date,
                datetime.datetime,
                datetime.time,
                decimal.Decimal,
                list,
            ),
        ):
            return f'{k}="{v}"'
        if isinstance(v, str):
            # escape % and backslash as before, and double embedded double quotes so
            # that the value cannot terminate the SQL string literal early
            v = v.replace("%", "%%").replace("\\", "\\\\").replace('"', '""')
            return f'{k}="{v}"'
        return f"{k}={v}"

    def join_conditions(negate, restrictions, operator="AND"):
        """parenthesize each restriction, join them with operator, optionally negate"""
        return ("NOT (%s)" if negate else "%s") % (
            f"({f') {operator} ('.join(restrictions)})"
        )

    # unwrap nested Not objects, toggling negation once per level
    negate = False
    while isinstance(condition, Not):
        negate = not negate
        condition = condition.restriction

    # restrict by string
    if isinstance(condition, str):
        columns.update(extract_column_names(condition))
        return join_conditions(
            negate, restrictions=[condition.strip().replace("%", "%%")]
        )  # escape %, see issue #376

    # restrict by AndList
    if isinstance(condition, AndList):
        # omit all conditions that evaluate to True
        items = [
            item
            for item in (
                make_condition(query_expression, cond, columns) for cond in condition
            )
            if item is not True
        ]
        if any(item is False for item in items):
            return negate  # if any item is False, the whole thing is False
        if not items:
            return not negate  # and empty AndList is True
        return join_conditions(negate, restrictions=items)

    # restriction by dj.U evaluates to True
    if isinstance(condition, U):
        return not negate

    # restrict by boolean
    if isinstance(condition, bool):
        return negate != condition

    # restrict by a mapping/dict -- convert to an AndList of string equality conditions
    if isinstance(condition, collections.abc.Mapping):
        # keys may be "attr.path" JSON references; match on the part before the first dot
        common_attributes = {c.split(".", 1)[0] for c in condition}.intersection(
            query_expression.heading.names
        )
        if not common_attributes:
            return not negate  # no matching attributes -> evaluates to True
        columns.update(common_attributes)
        return join_conditions(
            negate,
            restrictions=[
                prep_value(k, v)
                for k, v in condition.items()
                if k.split(".", 1)[0] in common_attributes
            ],
        )

    # restrict by a numpy record -- convert to an AndList of string equality conditions
    if isinstance(condition, numpy.void):
        common_attributes = set(condition.dtype.fields).intersection(
            query_expression.heading.names
        )
        if not common_attributes:
            return not negate  # no matching attributes -> evaluate to True
        columns.update(common_attributes)
        return join_conditions(
            negate,
            restrictions=[prep_value(k, condition[k]) for k in common_attributes],
        )

    # restrict by a QueryExpression subclass -- trigger instantiation and move on
    if inspect.isclass(condition) and issubclass(condition, QueryExpression):
        condition = condition()

    # restrict by another expression (aka semijoin and antijoin)
    check_compatibility = True
    if isinstance(condition, PromiscuousOperand):
        condition = condition.operand
        check_compatibility = False

    if isinstance(condition, QueryExpression):
        if check_compatibility:
            assert_join_compatibility(query_expression, condition)
        common_attributes = [
            q for q in condition.heading.names if q in query_expression.heading.names
        ]
        columns.update(common_attributes)
        if isinstance(condition, Aggregation):
            condition = condition.make_subquery()
        return (
            # without common attributes, any non-empty set matches everything
            (not negate if condition else negate)
            if not common_attributes
            else "({fields}) {not_}in ({subquery})".format(
                fields="`" + "`,`".join(common_attributes) + "`",
                not_="not " if negate else "",
                subquery=condition.make_sql(common_attributes),
            )
        )

    # restrict by pandas.DataFrames
    if isinstance(condition, pandas.DataFrame):
        condition = condition.to_records()  # convert to numpy.recarray and move on

    # if iterable (but not a string, a QueryExpression, or an AndList), treat as an OrList
    try:
        or_list = [make_condition(query_expression, q, columns) for q in condition]
    except TypeError:
        # {!r} instead of "%r" % condition: the latter fails when condition is a tuple
        raise DataJointError(f"Invalid restriction type {condition!r}")
    else:
        or_list = [
            item for item in or_list if item is not False
        ]  # ignore False conditions
        if any(item is True for item in or_list):  # if any item is True, entirely True
            return not negate
        return (
            join_conditions(negate, restrictions=or_list, operator="OR")
            if or_list
            else negate
        )
276

277

278
# Lowercase tokens that are SQL keywords or literals rather than column names.
# Built once at import time instead of on every call.
_SQL_NON_COLUMN_TOKENS = frozenset(
    {
        "is",
        "in",
        "between",
        "like",
        "and",
        "or",
        "null",
        "not",
        "interval",
        "second",
        "minute",
        "hour",
        "day",
        "month",
        "week",
        "year",
        # MySQL reserved words that can appear unquoted in a condition but can
        # never be unquoted column names
        "true",
        "false",
        "xor",
        "div",
        "mod",
        "regexp",
        "rlike",
    }
)


def extract_column_names(sql_expression):
    """
    extract all presumed column names from an sql expression such as the WHERE clause,
    for example.

    Only lowercase identifiers are recognized. Back-quoted identifiers are always
    reported; unquoted tokens are reported unless they are followed by ``(`` (treated
    as function calls) or are one of the known keywords.

    :param sql_expression: a string containing an SQL expression
    :return: set of extracted column names
    This may be MySQL-specific for now.
    """
    assert isinstance(sql_expression, str)
    result = set()
    s = sql_expression  # for terseness
    # remove escaped quotes
    s = re.sub(r"(\\\")|(\\\')", "", s)
    # remove quoted text
    s = re.sub(r"'[^']*'", "", s)
    s = re.sub(r'"[^"]*"', "", s)
    # find all tokens in back quotes and remove them
    result.update(re.findall(r"`([a-z][a-z_0-9]*)`", s))
    s = re.sub(r"`[a-z][a-z_0-9]*`", "", s)
    # remove space before parentheses
    s = re.sub(r"\s*\(", "(", s)
    # remove tokens followed by ( since they must be functions
    s = re.sub(r"(\b[a-z][a-z_0-9]*)\(", "(", s)
    remaining_tokens = set(re.findall(r"\b[a-z][a-z_0-9]*\b", s))
    # update result removing reserved words
    result.update(remaining_tokens - _SQL_NON_COLUMN_TOKENS)
    return result
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc