• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

xzkostyan / clickhouse-sqlalchemy / 11405726556

18 Oct 2024 02:28PM UTC coverage: 85.988% (-0.09%) from 86.082%
11405726556

push

github

xzkostyan
Version bumped to 0.2.7

1 of 1 new or added line in 1 file covered. (100.0%)

63 existing lines in 11 files now uncovered.

2332 of 2712 relevant lines covered (85.99%)

20.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.06
/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py
1
from sqlalchemy import exc, literal_column
16✔
2
from sqlalchemy.sql import compiler, elements, COLLECT_CARTESIAN_PRODUCTS, \
24✔
3
    WARN_LINTING, crud
4
from sqlalchemy.sql import type_api
24✔
5
from sqlalchemy.util import inspect_getfullargspec
24✔
6

7
from ... import types
24✔
8

9

10
class ClickHouseSQLCompiler(compiler.SQLCompiler):
    """Compile SQLAlchemy Core constructs into ClickHouse SQL.

    Only constructs whose ClickHouse rendering differs from the generic
    :class:`sqlalchemy.sql.compiler.SQLCompiler` are overridden here.
    """

    # ClickHouse function used by ``visit_extract`` for each EXTRACT field
    # (field names are first translated through ``self.extract_map``).
    # Fields missing from this table fall back to the bare expression.
    _extract_functions = {
        'year': 'toYear',
        'quarter': 'toQuarter',
        'month': 'toMonth',
        'day': 'toDayOfMonth',
        'doy': 'toDayOfYear',
        'hour': 'toHour',
        'minute': 'toMinute',
        # NOTE: toSecond yields whole seconds only (no fractional part).
        'second': 'toSecond',
    }

    def visit_mod_binary(self, binary, operator, **kw):
        """Render the modulo operator ``a % b``.

        The operator is emitted as ``%%``, the same escaping that
        ``post_process_text`` applies to textual SQL; presumably the driver
        runs the final statement through %-style parameter formatting.
        """
        return self.process(binary.left, **kw) + ' %% ' + \
            self.process(binary.right, **kw)

    def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
        """
        Implementation of distinctness comparison in ClickHouse SQL.
        A distinctness comparison treats NULL as if it is a (singleton)
        value and is what ClickHouse uses for `SELECT DISTINCT` and `GROUP BY`.
        Some databases have direct support for an `IS [NOT] DISTINCT FROM`
        comparison, but ClickHouse does not, so we rely on the `hasAny` array
        function here: ``hasAny([a], [b])``.
        """
        return "hasAny([%s], [%s])" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw),
        )

    def visit_is_distinct_from_binary(self, binary, operator, **kw):
        """Render ``a IS DISTINCT FROM b`` as ``NOT hasAny([a], [b])``."""
        return "NOT %s" % self.visit_is_not_distinct_from_binary(
            binary, operator, **kw
        )

    def visit_empty_set_expr(self, element_types, **kw):
        """Render a SELECT that yields no rows.

        SQLAlchemy calls this hook when an expanding ``IN`` parameter is
        empty. One ``CAST(NULL AS Nullable(T))`` column is produced per
        element type.

        :param element_types: column types to produce; when falsy, a single
            ``Int8`` column is used.
        :param kw: extra compiler options. Accepted (and ignored) so that
            SQLAlchemy versions forwarding keyword arguments here still work.
        """
        columns = ", ".join(
            "CAST(NULL AS %s)" % self.dialect.type_compiler.process(
                t if isinstance(t, types.Nullable) else types.Nullable(t)
            )
            for t in element_types or [types.Int8()]
        )
        return "SELECT %s WHERE 1!=1" % (columns,)

    def post_process_text(self, text):
        """Escape every ``%`` in literal SQL text as ``%%``."""
        return text.replace('%', '%%')

    def visit_count_func(self, fn, **kw):
        """Render ``count`` followed by its parenthesized argument list.

        ClickHouse's count accepts zero arguments, so nothing is injected
        when the argument list is empty.
        """
        return 'count%s' % self.process(fn.clause_expr, **kw)

    def visit_case(self, clause, **kwargs):
        """Render ``CASE [value] WHEN c THEN r ... [ELSE e] END``."""
        text = 'CASE '
        if clause.value is not None:
            text += clause.value._compiler_dispatch(self, **kwargs) + ' '
        for cond, result in clause.whens:
            text += (
                'WHEN ' + cond._compiler_dispatch(self, **kwargs) +
                ' THEN ' + result._compiler_dispatch(self, **kwargs) + ' '
            )

        if clause.else_ is not None:
            text += 'ELSE ' + clause.else_._compiler_dispatch(
                self, **kwargs
            ) + ' '

        text += 'END'
        return text

    def visit_if__func(self, func, **kw):
        """Render ``func.if_(cond, then, else)`` as ``(c) ? (t) : (e)``."""
        cond, then, otherwise = func.clauses.clauses[:3]
        return "(%s) ? (%s) : (%s)" % (
            self.process(cond, **kw),
            self.process(then, **kw),
            self.process(otherwise, **kw)
        )

    def limit_by_clause(self, select, **kw):
        """Render ClickHouse ``LIMIT [offset, ]n BY exprs``.

        Returns an empty string when the select has no LIMIT BY clause.
        """
        limit_by_clause = select._limit_by_clause
        if not limit_by_clause:
            return ''

        text = ' LIMIT '
        if limit_by_clause.offset is not None:
            text += self.process(limit_by_clause.offset, **kw) + ', '
        text += self.process(limit_by_clause.limit, **kw)
        limit_by_exprs = limit_by_clause.by_clauses._compiler_dispatch(
            self, **kw
        )
        text += ' BY ' + limit_by_exprs
        return text

    def limit_clause(self, select, **kw):
        """Render ``LIMIT [offset, ]count`` (ClickHouse comma form).

        :raises exc.CompileError: if an OFFSET is given without a LIMIT.
        """
        if select._limit_clause is None:
            if select._offset_clause is not None:
                raise exc.CompileError('OFFSET without LIMIT is not supported')
            return ''

        text = ' \n LIMIT '
        if select._offset_clause is not None:
            text += self.process(select._offset_clause, **kw) + ', '
        text += self.process(select._limit_clause, **kw)
        return text

    def visit_lambda(self, lambda_, **kw):
        """Render a Python callable as a ClickHouse lambda ``a, b -> expr``.

        The callable is invoked with one ``literal_column`` per positional
        parameter name, and its return value is compiled as the lambda body.

        :raises exc.CompileError: if the callable accepts ``*args`` or
            ``**kwargs``.
        """
        func = lambda_.func
        spec = inspect_getfullargspec(func)

        if spec.varargs:
            raise exc.CompileError('Lambdas with *args are not supported')

        try:
            # Legacy ``ArgSpec`` results name the **kwargs slot ``keywords``.
            keywords = spec.keywords
        except AttributeError:
            # ``FullArgSpec`` results name it ``varkw``.
            keywords = spec.varkw

        if keywords:
            raise exc.CompileError('Lambdas with **kwargs are not supported')

        params = [literal_column(name) for name in spec.args]
        body = self.process(func(*params), **kw)
        return ', '.join(spec.args) + ' -> ' + body

    def visit_extract(self, extract, **kw):
        """Render ``EXTRACT(field FROM expr)`` via ClickHouse functions.

        For example, ``year`` maps to ``toYear(expr)`` and ``hour`` to
        ``toHour(expr)``; see ``_extract_functions`` for the full table.
        """
        field = self.extract_map.get(extract.field, extract.field)
        column = self.process(extract.expr, **kw)
        function = self._extract_functions.get(field)
        if function is None:
            # NOTE(review): unsupported fields keep the historical behavior
            # of rendering the bare expression, which silently drops the
            # EXTRACT; consider raising CompileError instead.
            return column
        return '%s(%s)' % (function, column)

    def visit_join(self, join, asfrom=False, **kwargs):
        """Render a JOIN, including ClickHouse join modifiers.

        ``join.full`` doubles as a carrier for ClickHouse join options. It
        may be:

        * a plain truthy/falsy value (the standard FULL OUTER flag),
        * a dict, or
        * a tuple of ``(key, value)`` pairs.

        Recognised keys:

        * ``full``;
        * ``type``, which overrides the computed join type;
        * ``strictness`` and ``distribution``, which are upper-cased and
          prefixed to the join type.

        A ``Tuple`` onclause renders as ``USING (...)``; anything else renders
        as ``ON ...``.

        :raises exc.CompileError: if an explicit INNER type is combined with
            ``isouter=True``.
        """
        text = join.left._compiler_dispatch(self, asfrom=True, **kwargs)

        # NOTE(review): strips one pair of outer parentheses from the left
        # side; this is only correct when they enclose the whole text.
        if text.startswith('(') and text.endswith(')'):
            text = text[1:-1]

        flags = join.full
        if not isinstance(flags, dict):
            if isinstance(flags, tuple):
                flags = dict(flags)
            else:
                flags = {'full': flags}

        join_type = flags.get('type')
        if join_type is None:
            if flags.get('full'):
                join_type = 'FULL OUTER'
            elif join.isouter:
                join_type = 'LEFT OUTER'
            else:
                join_type = 'INNER'
        else:
            join_type = join_type.upper()
            if join.isouter and 'INNER' in join_type:
                raise exc.CompileError(
                    "can't compile join with specified "
                    "INNER type and isouter=True"
                )
            # An explicit OUTER type with isouter=False is intentionally
            # accepted: isouter defaults to False, so it carries no intent.

        strictness = flags.get('strictness')
        if strictness:
            join_type = strictness.upper() + ' ' + join_type

        distribution = flags.get('distribution')
        if distribution:
            join_type = distribution.upper() + ' ' + join_type

        # Every component above is already upper-case.
        text += ' ' + join_type + ' JOIN '
        text += join.right._compiler_dispatch(self, asfrom=True, **kwargs)

        onclause = join.onclause
        if isinstance(onclause, elements.Tuple):
            text += ' USING ' + onclause._compiler_dispatch(
                self, include_table=False, **kwargs
            )
        else:
            text += ' ON ' + onclause._compiler_dispatch(self, **kwargs)
        return text

    def _render_array_join(self, join_keyword, array_join, **kwargs):
        """Render a newline-prefixed ``<join_keyword> col1, col2, ...``."""
        kwargs['within_columns_clause'] = True
        columns = ', '.join(
            col._compiler_dispatch(self, within_label_clause=False, **kwargs)
            for col in array_join.clauses
        )
        return ' \n%s %s' % (join_keyword, columns)

    def visit_array_join(self, array_join, **kwargs):
        """Render an ``ARRAY JOIN`` clause over ``array_join.clauses``."""
        return self._render_array_join('ARRAY JOIN', array_join, **kwargs)

    def visit_left_array_join(self, array_join, **kwargs):
        """Render a ``LEFT ARRAY JOIN`` clause over ``array_join.clauses``."""
        return self._render_array_join(
            'LEFT ARRAY JOIN', array_join, **kwargs
        )

    def visit_label(self,
                    label,
                    from_labeled_label=False,
                    **kw):
        """Render a label.

        When ``from_labeled_label`` is true, the label is rendered with
        ``render_label_as_label=label``, and the remaining compiler keyword
        arguments are not forwarded.
        """
        if from_labeled_label:
            return super(ClickHouseSQLCompiler, self).visit_label(
                label,
                render_label_as_label=label
            )
        return super(ClickHouseSQLCompiler, self).visit_label(
            label,
            **kw
        )

    def _compose_select_body(
        self,
        text,
        select,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    ):
        """Append the column list and every post-SELECT clause to ``text``.

        Clauses are emitted in this order:

        * FROM
        * FINAL
        * SAMPLE
        * ARRAY JOIN
        * WHERE
        * GROUP BY
        * HAVING
        * ORDER BY
        * LIMIT BY
        * LIMIT
        * FOR UPDATE
        """
        text += ", ".join(inner_columns)

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = compiler.FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        if froms:
            # FROM hints are only forwarded when the statement carries any.
            hint_kw = {'fromhints': byfrom} if select._hints else {}
            text += " \nFROM " + ", ".join(
                f._compiler_dispatch(
                    self,
                    asfrom=True,
                    from_linter=from_linter,
                    **hint_kw,
                    **kwargs
                )
                for f in froms
            )
        else:
            text += self.default_from()

        # ClickHouse parses ``FROM t [FINAL] [SAMPLE k] [ARRAY JOIN ...]``,
        # so FINAL has to come before SAMPLE and ARRAY JOIN.
        if getattr(select, '_final_clause', None) is not None:
            text += self.final_clause()

        if getattr(select, '_sample_clause', None) is not None:
            text += self.sample_clause(select, **kwargs)

        array_join = getattr(select, '_array_join', None)
        if array_join is not None:
            text += array_join._compiler_dispatch(self, **kwargs)

        if select._where_criteria:
            t = self._generate_delimited_and_list(
                select._where_criteria, from_linter=from_linter, **kwargs
            )
            if t:
                text += " \nWHERE " + t

        if warn_linting:
            from_linter.warn()

        if select._group_by_clauses:
            text += self.group_by_clause(select, **kwargs)

        if select._having_criteria:
            t = self._generate_delimited_and_list(
                select._having_criteria, **kwargs
            )
            if t:
                text += " \nHAVING " + t

        if select._order_by_clauses:
            text += self.order_by_clause(select, **kwargs)

        if getattr(select, '_limit_by_clause', None) is not None:
            text += self.limit_by_clause(select, **kwargs)

        if select._has_row_limiting_clause:
            text += self._row_limit_clause(select, **kwargs)

        if select._for_update_arg is not None:
            text += self.for_update_clause(select, **kwargs)

        return text

    def sample_clause(self, select, **kw):
        """Render ``SAMPLE <expr>`` from ``select._sample_clause``."""
        return " \nSAMPLE " + self.process(select._sample_clause, **kw)

    def final_clause(self):
        """Render the ``FINAL`` table modifier."""
        return " \nFINAL"

    def group_by_clause(self, select, **kw):
        """Render ``GROUP BY ...`` with optional ``WITH CUBE/ROLLUP/TOTALS``.

        Returns an empty string when the GROUP BY list compiles to nothing.
        """
        group_by = select._group_by_clause._compiler_dispatch(self, **kw)
        if not group_by:
            return ""

        text = " GROUP BY " + group_by
        if getattr(select, '_with_cube', False):
            text += " WITH CUBE"
        if getattr(select, '_with_rollup', False):
            text += " WITH ROLLUP"
        if getattr(select, '_with_totals', False):
            text += " WITH TOTALS"
        return text

    def visit_delete(self, delete_stmt, **kw):
        """Render a DELETE as ``ALTER TABLE <t> DELETE WHERE <cond>``.

        :raises exc.CompileError: if the dialect does not support mutations,
            or if no (non-empty) WHERE clause is present.
        """
        if not self.dialect.supports_delete:
            raise exc.CompileError(
                'ALTER DELETE is not supported by this server version'
            )

        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        delete_stmt = compile_state.statement

        if not delete_stmt._where_criteria:
            raise exc.CompileError('WHERE clause is required')

        extra_froms = compile_state._extra_froms
        correlate_froms = {delete_stmt.table}.union(extra_froms)
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )
        # Always pop the entry pushed above, even if compilation fails.
        try:
            table_text = self.delete_table_clause(
                delete_stmt, delete_stmt.table, extra_froms
            )
            where_text = self._generate_delimited_and_list(
                delete_stmt._where_criteria, include_table=False, **kw
            )
        finally:
            self.stack.pop(-1)

        # ALTER ... DELETE without a condition is not valid ClickHouse SQL.
        if not where_text:
            raise exc.CompileError('WHERE clause is required')

        return "ALTER TABLE " + table_text + " DELETE WHERE " + where_text

    def visit_update(self, update_stmt, **kw):
        """Render an UPDATE as ``ALTER TABLE <t> UPDATE c=v, ... WHERE ...``.

        :raises exc.CompileError: if the dialect does not support mutations,
            or if no (non-empty) WHERE clause is present.
        """
        if not self.dialect.supports_update:
            raise exc.CompileError(
                'ALTER UPDATE is not supported by this server version'
            )

        compile_state = update_stmt._compile_state_factory(
            update_stmt, self, **kw
        )
        update_stmt = compile_state.statement

        if not update_stmt._where_criteria:
            raise exc.CompileError('WHERE clause is required')

        render_extra_froms = []
        correlate_froms = {update_stmt.table}
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )
        # Always pop the entry pushed above, even if compilation fails.
        try:
            table_text = self.update_tables_clause(
                update_stmt, update_stmt.table, render_extra_froms, **kw
            )
            crud_params = crud._get_crud_params(
                self, update_stmt, compile_state, **kw
            )
            set_text = ", ".join(
                expr + "=" + value for _column, expr, value in crud_params
            )
            where_text = self._generate_delimited_and_list(
                update_stmt._where_criteria, include_table=False, **kw
            )
        finally:
            self.stack.pop(-1)

        # ALTER ... UPDATE without a condition is not valid ClickHouse SQL.
        if not where_text:
            raise exc.CompileError('WHERE clause is required')

        return (
            "ALTER TABLE " + table_text + " UPDATE " + set_text +
            " WHERE " + where_text
        )

    def render_literal_value(self, value, type_):
        """Render ``value`` as an inline SQL literal.

        Python lists become ClickHouse array literals ``[a, b, ...]``. Each
        element's type is resolved from its Python value, so nested lists are
        handled recursively. Other values are delegated to SQLAlchemy.
        """
        if isinstance(value, list):
            return (
                '[' +
                ', '.join(self.render_literal_value(
                        x, type_api._resolve_value_to_type(x)
                    ) for x in value) +
                ']'
            )
        return super(ClickHouseSQLCompiler, self).render_literal_value(
            value, type_
        )

    def _get_regexp_args(self, binary, kw):
        """Compile and return the ``(string, pattern)`` operand pair."""
        string = self.process(binary.left, **kw)
        pattern = self.process(binary.right, **kw)
        return string, pattern

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        """Render ``col.regexp_match(p)`` as ``MATCH(col, p)``."""
        string, pattern = self._get_regexp_args(binary, kw)
        return "MATCH(%s, %s)" % (string, pattern)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        """Render the negated regexp match as ``NOT MATCH(col, p)``."""
        return "NOT %s" % self.visit_regexp_match_op_binary(
            binary,
            operator,
            **kw
        )

    def visit_ilike_case_insensitive_operand(self, element, **kw):
        """Render the wrapped operand unchanged.

        ClickHouse has a native ILIKE operator (see ``visit_ilike_op_binary``),
        so no case-folding wrapper is emitted around the operand.
        """
        return element.element._compiler_dispatch(self, **kw)

    def visit_ilike_op_binary(self, binary, operator, **kw):
        """Render ``a ILIKE b`` using ClickHouse's native operator."""
        return "%s ILIKE %s" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw)
        )

    def visit_not_ilike_op_binary(self, binary, operator, **kw):
        """Render ``a NOT ILIKE b`` using ClickHouse's native operator."""
        return "%s NOT ILIKE %s" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw)
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc