• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

xzkostyan / clickhouse-sqlalchemy / 11406011959

18 Oct 2024 02:47PM UTC coverage: 86.142% (-0.1%) from 86.281%
11406011959

push

github

xzkostyan
Fix test_execute_full_join

2505 of 2908 relevant lines covered (86.14%)

5.12 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.73
/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py
1
from sqlalchemy import exc, literal_column
4✔
2
from sqlalchemy.sql import compiler, elements, COLLECT_CARTESIAN_PRODUCTS, \
6✔
3
    WARN_LINTING, crud
4
from sqlalchemy.sql import type_api
6✔
5
from sqlalchemy.util import inspect_getfullargspec
6✔
6

7
import clickhouse_sqlalchemy.sql.functions  # noqa:F401
6✔
8

9
from ... import types
6✔
10

11

12
class ClickHouseSQLCompiler(compiler.SQLCompiler):
6✔
13
    def visit_mod_binary(self, binary, operator, **kw):
        """Render the modulo operator as ``left %% right``.

        The ``%`` sign is emitted doubled (presumably so that it survives
        %-style parameter interpolation by the driver).
        """
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return ' %% '.join((lhs, rhs))
16

17
    def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
        """Render ``left IS NOT DISTINCT FROM right`` for ClickHouse.

        ClickHouse has no ``IS [NOT] DISTINCT FROM`` operator. Instead each
        operand is wrapped in a one-element array and compared with
        ``hasAny``, which treats NULL as an ordinary (singleton) value. This
        is the same NULL handling ClickHouse applies for ``SELECT DISTINCT``
        and ``GROUP BY``.
        """
        operands = (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw),
        )
        return "hasAny([%s], [%s])" % operands
30

31
    def visit_is_distinct_from_binary(self, binary, operator, **kw):
        """Render ``left IS DISTINCT FROM right``.

        This is the negation of :meth:`visit_is_not_distinct_from_binary`.
        """
        not_distinct = self.visit_is_not_distinct_from_binary(
            binary, operator, **kw
        )
        return "NOT " + not_distinct
35

36
    def visit_empty_set_expr(self, element_types, **kw):
        """Render a sub-select that yields no rows.

        Presumably this is invoked by SQLAlchemy when an expanding ``IN``
        parameter receives an empty sequence.

        :param element_types: SQL types of the ``IN`` operand (several for a
            tuple ``IN``). When empty or ``None``, a single ``Int8`` column
            is used.
        :param kw: accepted for signature compatibility with
            ``SQLCompiler.visit_empty_set_expr``; unused here.
        :returns: ``SELECT CAST(NULL AS Nullable(T)), ... WHERE 1!=1``.
        """
        if not element_types:
            element_types = [types.Int8()]

        # Wrap non-Nullable types so that casting NULL to them is valid.
        casts = ", ".join(
            "CAST(NULL AS %s)"
            % self.dialect.type_compiler.process(
                t if isinstance(t, types.Nullable) else types.Nullable(t)
            )
            for t in element_types
        )
        return "SELECT %s WHERE 1!=1" % casts
46

47
    def post_process_text(self, text):
        """Return literal SQL ``text`` with every ``%`` doubled to ``%%``.

        Unlike the base compiler, this escaping is applied unconditionally.
        """
        escaped = text.replace('%', '%%')
        return escaped
49

50
    def visit_count_func(self, fn, **kw):
        """Render ``count(...)``.

        ClickHouse's ``count`` may take zero arguments. The argument list is
        rendered from ``fn.clause_expr``, which supplies its own
        parentheses, and is appended directly to the function name.
        """
        arguments = self.process(fn.clause_expr, **kw)
        return 'count' + arguments
53

54
    def visit_case(self, clause, **kwargs):
        """Render a ``CASE`` expression.

        Output shape: ``CASE [value ]WHEN c THEN r [WHEN ...] [ELSE e ]END``.
        Each component is followed by a single space.
        """
        parts = ['CASE ']

        if clause.value is not None:
            parts.append(
                clause.value._compiler_dispatch(self, **kwargs) + ' '
            )

        for condition, outcome in clause.whens:
            condition_sql = condition._compiler_dispatch(self, **kwargs)
            outcome_sql = outcome._compiler_dispatch(self, **kwargs)
            parts.append(
                'WHEN ' + condition_sql + ' THEN ' + outcome_sql + ' '
            )

        if clause.else_ is not None:
            parts.append(
                'ELSE ' + clause.else_._compiler_dispatch(self, **kwargs) + ' '
            )

        parts.append('END')
        return ''.join(parts)
71

72
    def visit_if__func(self, func, **kw):
        """Render ``func.if_(cond, then, else)`` as a ternary expression.

        The output is ClickHouse's ternary operator
        ``(cond) ? (then) : (else)``. Only the first three arguments are
        used.
        """
        args = func.clauses.clauses
        condition, when_true, when_false = (
            self.process(args[i], **kw) for i in range(3)
        )
        return "(%s) ? (%s) : (%s)" % (condition, when_true, when_false)
78

79
    def limit_by_clause(self, select, **kw):
        """Render ClickHouse's ``LIMIT [offset, ]n BY expr, ...`` clause.

        Returns an empty string when ``select._limit_by_clause`` is falsy.
        """
        limit_by = select._limit_by_clause
        if not limit_by:
            return ''

        offset_sql = ''
        if limit_by.offset is not None:
            offset_sql = self.process(limit_by.offset, **kw) + ', '
        count_sql = self.process(limit_by.limit, **kw)
        by_sql = limit_by.by_clauses._compiler_dispatch(self, **kw)

        return ' LIMIT ' + offset_sql + count_sql + ' BY ' + by_sql
93

94
    def limit_clause(self, select, **kw):
        """Render the row limit as ``LIMIT [offset, ]limit``.

        The offset can only be expressed in the ``LIMIT offset, limit``
        form, so it is rejected when no limit is present.

        :raises CompileError: if an OFFSET is given without a LIMIT.
        """
        limit = select._limit_clause
        offset = select._offset_clause

        if limit is None:
            if offset is not None:
                raise exc.CompileError('OFFSET without LIMIT is not supported')
            return ''

        text = ' \n LIMIT '
        if offset is not None:
            text += self.process(offset, **kw) + ', '
        return text + self.process(limit, **kw)
106

107
    def visit_lambda(self, lambda_, **kw):
        """Render a Python callable as a ClickHouse lambda ``a, b -> expr``.

        The callable is invoked with one ``literal_column`` per positional
        parameter, each named after that parameter. The expression it
        returns is compiled as the lambda body.

        :raises CompileError: if the callable takes ``*args``, ``**kwargs``,
            or keyword-only parameters without defaults. None of these can be
            expressed as a ClickHouse lambda.
        """
        func = lambda_.func
        spec = inspect_getfullargspec(func)

        if spec.varargs:
            raise exc.CompileError('Lambdas with *args are not supported')

        # The legacy ArgSpec tuple (SQLAlchemy < 1.3.0b2) names this field
        # ``keywords``; FullArgSpec (SQLAlchemy >= 1.3.0b2) calls it ``varkw``.
        try:
            keywords = spec.keywords
        except AttributeError:
            keywords = spec.varkw

        if keywords:
            raise exc.CompileError('Lambdas with **kwargs are not supported')

        # Keyword-only parameters are never passed to ``func`` below. Without
        # a default they would make ``func(*args)`` fail with a bare
        # TypeError, so reject them with a clear compile error instead.
        kwonly_defaults = getattr(spec, 'kwonlydefaults', None) or {}
        required_kwonly = [
            name
            for name in (getattr(spec, 'kwonlyargs', None) or [])
            if name not in kwonly_defaults
        ]
        if required_kwonly:
            raise exc.CompileError(
                'Lambdas with required keyword-only arguments '
                'are not supported'
            )

        args = [literal_column(arg) for arg in spec.args]
        header = ', '.join(spec.args) + ' -> '
        return header + self.process(func(*args), **kw)
130

131
    def visit_extract(self, extract, **kw):
        """Render ``extract(field, expr)`` with ClickHouse date functions.

        Supported fields and their functions:

        * ``year`` -> ``toYear``
        * ``month`` -> ``toMonth``
        * ``day`` -> ``toDayOfMonth``
        * ``hour`` -> ``toHour``
        * ``minute`` -> ``toMinute``
        * ``second`` -> ``toSecond``
        * ``quarter`` -> ``toQuarter``
        * ``doy`` -> ``toDayOfYear``
        * ``week`` -> ``toISOWeek``

        Any other field renders the bare expression. This keeps the
        existing behaviour, but note that it silently drops the extraction.
        """
        field = self.extract_map.get(extract.field, extract.field)
        column = self.process(extract.expr, **kw)

        functions = {
            'year': 'toYear',
            'month': 'toMonth',
            'day': 'toDayOfMonth',
            'hour': 'toHour',
            'minute': 'toMinute',
            'second': 'toSecond',
            'quarter': 'toQuarter',
            'doy': 'toDayOfYear',
            'week': 'toISOWeek',
        }
        function = functions.get(field)
        if function is None:
            # NOTE(review): unknown fields fall through unchanged (kept for
            # backward compatibility); consider raising CompileError.
            return column
        return '%s(%s)' % (function, column)
142

143
    def visit_join(self, join, asfrom=False, **kwargs):
        """Render a join using ClickHouse's extended JOIN syntax.

        ``join.full`` carries the ClickHouse join options. It may be:

        * a dict,
        * a tuple of ``(key, value)`` pairs, or
        * any other value, which is treated as ``{'full': value}``.

        Recognised keys:

        * ``type``: explicit join type, upper-cased (e.g. ``'left outer'``).
        * ``full``: when no ``type`` is given, a truthy value selects
          ``FULL OUTER``. Otherwise ``isouter`` picks ``LEFT OUTER`` or
          ``INNER``.
        * ``strictness``: upper-cased prefix placed before the type.
        * ``distribution``: upper-cased prefix placed before the strictness.

        An ``elements.Tuple`` onclause is rendered as ``USING ...``;
        anything else is rendered as ``ON ...``.

        :raises CompileError: if an explicit type containing ``INNER`` is
            combined with ``isouter=True``.
        """
        left_sql = join.left._compiler_dispatch(self, asfrom=True, **kwargs)
        # Strip one enclosing pair of parentheses from the left side.
        if left_sql[0] == '(' and left_sql[-1] == ')':
            left_sql = left_sql[1:-1]

        options = join.full
        if isinstance(options, tuple):
            options = dict(options)
        elif not isinstance(options, dict):
            options = {'full': options}

        requested_type = options.get('type')
        if requested_type is None:
            if options.get('full'):
                join_type = 'FULL OUTER'
            elif join.isouter:
                join_type = 'LEFT OUTER'
            else:
                join_type = 'INNER'
        else:
            join_type = requested_type.upper()
            # An OUTER type with isouter=False is intentionally accepted,
            # because isouter is False by default.
            if join.isouter and 'INNER' in join_type:
                raise exc.CompileError(
                    "can't compile join with specified "
                    "INNER type and isouter=True"
                )

        # Prefixes are prepended innermost first, so the final order is
        # "<distribution> <strictness> <type>".
        for option in ('strictness', 'distribution'):
            prefix = options.get(option)
            if prefix:
                join_type = prefix.upper() + ' ' + join_type

        right_sql = join.right._compiler_dispatch(self, asfrom=True, **kwargs)
        sql = left_sql + ' ' + join_type.upper() + ' JOIN ' + right_sql

        onclause = join.onclause
        if isinstance(onclause, elements.Tuple):
            return sql + ' USING ' + onclause._compiler_dispatch(
                self, include_table=False, **kwargs
            )
        return sql + ' ON ' + onclause._compiler_dispatch(self, **kwargs)
199

200
    def visit_array_join(self, array_join, **kwargs):
        """Render an ``ARRAY JOIN col[, col ...]`` clause.

        Columns are compiled with ``within_columns_clause=True`` and
        ``within_label_clause=False``. SQLAlchemy's label compilation
        consults these flags when deciding whether to render
        ``expr AS name``.
        """
        kwargs['within_columns_clause'] = True
        # Set the flag in kwargs rather than passing it next to **kwargs.
        # The enclosing select's kwargs may already contain
        # ``within_label_clause`` (e.g. inside a labelled scalar subquery),
        # and passing it twice would raise TypeError.
        kwargs['within_label_clause'] = False

        columns = ', '.join(
            col._compiler_dispatch(self, **kwargs)
            for col in array_join.clauses
        )
        return ' \nARRAY JOIN {columns}'.format(columns=columns)
212

213
    def visit_left_array_join(self, array_join, **kwargs):
        """Render a ``LEFT ARRAY JOIN col[, col ...]`` clause.

        Columns are compiled with ``within_columns_clause=True`` and
        ``within_label_clause=False``. SQLAlchemy's label compilation
        consults these flags when deciding whether to render
        ``expr AS name``.
        """
        kwargs['within_columns_clause'] = True
        # Set the flag in kwargs rather than passing it next to **kwargs.
        # The enclosing select's kwargs may already contain
        # ``within_label_clause`` (e.g. inside a labelled scalar subquery),
        # and passing it twice would raise TypeError.
        kwargs['within_label_clause'] = False

        columns = ', '.join(
            col._compiler_dispatch(self, **kwargs)
            for col in array_join.clauses
        )
        return ' \nLEFT ARRAY JOIN {columns}'.format(columns=columns)
225

226
    def visit_label(self,
                    label,
                    from_labeled_label=False,
                    **kw):
        """Compile a label, delegating to the base compiler.

        When ``from_labeled_label`` is true, the base implementation is
        called with ``render_label_as_label=label``, and the remaining
        keyword arguments are not forwarded. Otherwise every keyword
        argument is passed through unchanged.
        """
        parent = super(ClickHouseSQLCompiler, self)
        if not from_labeled_label:
            return parent.visit_label(label, **kw)
        return parent.visit_label(label, render_label_as_label=label)
240

241
    def _compose_select_body(
        self,
        text,
        select,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    ):
        """Render everything after ``SELECT [DISTINCT ...]`` for ``select``.

        This mirrors ``SQLCompiler._compose_select_body`` and adds the
        ClickHouse-specific clauses:

        * table modifiers, emitted in the order ClickHouse's grammar
          requires: ``FROM ... [FINAL] [SAMPLE k] [ARRAY JOIN ...]``;
        * ``LIMIT n BY ...``, rendered before the regular row limit.

        :param text: SQL rendered so far (``SELECT`` plus precolumns).
        :param kwargs: compile keyword arguments (a dict, not ``**kwargs``),
            forwarded to every sub-element.
        :returns: ``text`` extended with the full select body.
        """
        text += ", ".join(inner_columns)

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = compiler.FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        if froms:
            from_kw = {'asfrom': True, 'from_linter': from_linter}
            if select._hints:
                # Table hints are only forwarded when the select has any.
                from_kw['fromhints'] = byfrom
            text += " \nFROM " + ", ".join(
                [
                    f._compiler_dispatch(self, **from_kw, **kwargs)
                    for f in froms
                ]
            )
        else:
            text += self.default_from()

        # FINAL must directly follow the table expression, before SAMPLE
        # and ARRAY JOIN. NOTE(review): with joins, these modifiers still
        # land after the whole join chain; confirm ClickHouse accepts that.
        if getattr(select, '_final_clause', None) is not None:
            text += self.final_clause()

        if getattr(select, '_sample_clause', None) is not None:
            text += self.sample_clause(select, **kwargs)

        if getattr(select, '_array_join', None) is not None:
            text += select._array_join._compiler_dispatch(self, **kwargs)

        if select._where_criteria:
            t = self._generate_delimited_and_list(
                select._where_criteria, from_linter=from_linter, **kwargs
            )
            if t:
                text += " \nWHERE " + t

        if warn_linting:
            from_linter.warn()

        if select._group_by_clauses:
            text += self.group_by_clause(select, **kwargs)

        if select._having_criteria:
            t = self._generate_delimited_and_list(
                select._having_criteria, **kwargs
            )
            if t:
                text += " \nHAVING " + t

        if select._order_by_clauses:
            text += self.order_by_clause(select, **kwargs)

        if getattr(select, '_limit_by_clause', None) is not None:
            text += self.limit_by_clause(select, **kwargs)

        if select._has_row_limiting_clause:
            text += self._row_limit_clause(select, **kwargs)

        if select._for_update_arg is not None:
            text += self.for_update_clause(select, **kwargs)

        return text
342

343
    def sample_clause(self, select, **kw):
        """Render ``SAMPLE <expr>`` from ``select._sample_clause``."""
        sample_sql = self.process(select._sample_clause, **kw)
        return " \nSAMPLE %s" % sample_sql
345

346
    def final_clause(self):
        """Return the ClickHouse ``FINAL`` table modifier on a new line."""
        modifier = " \nFINAL"
        return modifier
348

349
    def group_by_clause(self, select, **kw):
        """Render ``GROUP BY`` with optional ClickHouse modifiers.

        The modifiers ``WITH CUBE``, ``WITH ROLLUP`` and ``WITH TOTALS`` are
        appended in that order when the matching ``_with_*`` attribute on
        ``select`` is truthy. An empty string is returned (and no modifiers
        are emitted) when the group-by list renders empty.
        """
        group_by = select._group_by_clause._compiler_dispatch(self, **kw)
        if not group_by:
            return ""

        modifiers = [
            suffix
            for attr, suffix in (
                ('_with_cube', " WITH CUBE"),
                ('_with_rollup', " WITH ROLLUP"),
                ('_with_totals', " WITH TOTALS"),
            )
            if getattr(select, attr, False)
        ]
        return " GROUP BY " + group_by + "".join(modifiers)
368

369
    def visit_delete(self, delete_stmt, **kw):
        """Compile a DELETE as ``ALTER TABLE <table> DELETE WHERE <expr>``.

        :raises CompileError: in any of these cases:

            * the dialect does not support ``ALTER DELETE``;
            * the statement has no WHERE criteria;
            * the criteria render to an empty string.

            ClickHouse's ``ALTER ... DELETE`` syntax always requires a WHERE
            clause.
        """
        if not self.dialect.supports_delete:
            raise exc.CompileError(
                'ALTER DELETE is not supported by this server version'
            )

        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        delete_stmt = compile_state.statement

        # Fail before pushing anything onto the compiler stack.
        if not delete_stmt._where_criteria:
            raise exc.CompileError('WHERE clause is required')

        extra_froms = compile_state._extra_froms
        correlate_froms = {delete_stmt.table}.union(extra_froms)
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )
        # Always pop the frame pushed above, even if compilation fails.
        try:
            table_text = self.delete_table_clause(
                delete_stmt, delete_stmt.table, extra_froms
            )
            where_text = self._generate_delimited_and_list(
                delete_stmt._where_criteria, include_table=False, **kw
            )
        finally:
            self.stack.pop(-1)

        if not where_text:
            # Criteria that render to nothing would otherwise produce an
            # invalid "ALTER TABLE t DELETE" with no WHERE clause.
            raise exc.CompileError('WHERE clause is required')

        return "ALTER TABLE " + table_text + " DELETE WHERE " + where_text
411

412
    def visit_update(self, update_stmt, **kw):
        """Compile an UPDATE as a ClickHouse ``ALTER TABLE ... UPDATE``.

        The rendered form is
        ``ALTER TABLE <table> UPDATE col=value, ... WHERE <expr>``.

        :raises CompileError: in any of these cases:

            * the dialect does not support ``ALTER UPDATE``;
            * the statement has no WHERE criteria;
            * the criteria render to an empty string.

            ClickHouse's ``ALTER ... UPDATE`` syntax always requires a WHERE
            clause.
        """
        if not self.dialect.supports_update:
            raise exc.CompileError(
                'ALTER UPDATE is not supported by this server version'
            )

        compile_state = update_stmt._compile_state_factory(
            update_stmt, self, **kw
        )
        update_stmt = compile_state.statement

        # Fail before pushing anything onto the compiler stack.
        if not update_stmt._where_criteria:
            raise exc.CompileError('WHERE clause is required')

        render_extra_froms = []
        correlate_froms = {update_stmt.table}
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )
        # Always pop the frame pushed above, even if compilation fails.
        try:
            table_text = self.update_tables_clause(
                update_stmt, update_stmt.table, render_extra_froms, **kw
            )
            crud_params = crud._get_crud_params(
                self, update_stmt, compile_state, True, **kw
            )
            assignments = ", ".join(
                expr + "=" + value
                for _column, expr, value, _ in crud_params.single_params
            )
            where_text = self._generate_delimited_and_list(
                update_stmt._where_criteria, include_table=False, **kw
            )
        finally:
            self.stack.pop(-1)

        if not where_text:
            # Criteria that render to nothing would otherwise produce an
            # UPDATE with no WHERE clause, which ClickHouse rejects.
            raise exc.CompileError('WHERE clause is required')

        return (
            "ALTER TABLE " + table_text + " UPDATE " + assignments
            + " WHERE " + where_text
        )
461

462
    def render_literal_value(self, value, type_):
        """Render ``value`` as an inline SQL literal.

        A Python ``list`` becomes a ClickHouse array literal ``[a, b, ...]``.
        Each element is rendered recursively using the type SQLAlchemy
        resolves from the element itself, so ``type_`` is not used for
        lists. All other values are delegated to the base compiler.
        """
        if not isinstance(value, list):
            return super(ClickHouseSQLCompiler, self).render_literal_value(
                value, type_
            )

        rendered = [
            self.render_literal_value(
                item, type_api._resolve_value_to_type(item)
            )
            for item in value
        ]
        return '[' + ', '.join(rendered) + ']'
475

476
    def _get_regexp_args(self, binary, kw):
        """Compile both operands of a regexp binary expression.

        :param kw: compile keyword arguments as a plain dict.
        :returns: ``(string_sql, pattern_sql)``, left operand first.
        """
        return (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw),
        )
480

481
    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        """Render ``regexp_match`` as ClickHouse ``match(string, pattern)``.

        ``match`` uses RE2 syntax and has no separate flags argument. Any
        ``flags`` passed to ``regexp_match`` are therefore applied as an
        inline ``(?flags)`` group prefixed to the pattern with ``concat``.
        The prefix is rendered as an escaped string literal.
        """
        string, pattern = self._get_regexp_args(binary, kw)

        flags = binary.modifiers.get('flags')
        if flags:
            prefix = '(?%s)' % flags
            prefix_sql = self.render_literal_value(
                prefix, type_api._resolve_value_to_type(prefix)
            )
            pattern = 'concat(%s, %s)' % (prefix_sql, pattern)

        return "match(%s, %s)" % (string, pattern)
484

485
    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        """Render ``not_regexp_match``.

        The result is the output of :meth:`visit_regexp_match_op_binary`
        prefixed with ``NOT``.
        """
        positive = self.visit_regexp_match_op_binary(binary, operator, **kw)
        return "NOT " + positive
491

492
    def visit_ilike_case_insensitive_operand(self, element, **kw):
        """Compile the operand wrapped for ILIKE without any case folding.

        SQLAlchemy's base compiler renders this wrapper as ``lower(...)``.
        Here the inner element is compiled unchanged, because ILIKE is
        rendered natively by this compiler.
        """
        inner = element.element
        return inner._compiler_dispatch(self, **kw)
494

495
    def visit_ilike_op_binary(self, binary, operator, **kw):
        """Render ``left ILIKE right`` using the native ILIKE operator.

        NOTE(review): any ``escape=`` modifier on the expression is not
        rendered.
        """
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return lhs + " ILIKE " + rhs
500

501
    def visit_not_ilike_op_binary(self, binary, operator, **kw):
        """Render ``left NOT ILIKE right`` using the native operator.

        NOTE(review): any ``escape=`` modifier on the expression is not
        rendered.
        """
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return lhs + " NOT ILIKE " + rhs
506

507
    def get_select_precolumns(self, select, **kw):
        """Return the text placed between ``SELECT`` and the column list.

        The result is one of:

        * ``DISTINCT ON (col, ...) `` when ``select._distinct_on`` is set;
        * ``DISTINCT `` when only ``select._distinct`` is set;
        * an empty string otherwise.

        The base implementation is not called because it warns or raises
        when DISTINCT ON is present.
        """
        distinct_on = select._distinct_on
        if distinct_on:
            columns = ", ".join(
                [self.process(col, **kw) for col in distinct_on]
            )
            return "DISTINCT ON (" + columns + ") "
        if select._distinct:
            return "DISTINCT "
        return ""
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc