• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 9791636125

04 Jul 2024 09:10AM UTC coverage: 89.523% (+0.1%) from 89.416%
9791636125

push

gh-ci

ostafen
Implement CHECK constraints

Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>

554 of 628 new or added lines in 10 files covered. (88.22%)

14 existing lines in 7 files now uncovered.

35374 of 39514 relevant lines covered (89.52%)

160396.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.54
/embedded/sql/parser.go
1
/*
2
Copyright 2024 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "encoding/hex"
22
        "errors"
23
        "fmt"
24
        "io"
25
        "strconv"
26
        "strings"
27
)
28

29
//go:generate go run golang.org/x/tools/cmd/goyacc -l -o sql_parser.go sql_grammar.y
30

31
var reservedWords = map[string]int{
32
        "CREATE":         CREATE,
33
        "DROP":           DROP,
34
        "USE":            USE,
35
        "DATABASE":       DATABASE,
36
        "SNAPSHOT":       SNAPSHOT,
37
        "HISTORY":        HISTORY,
38
        "OF":             OF,
39
        "SINCE":          SINCE,
40
        "AFTER":          AFTER,
41
        "BEFORE":         BEFORE,
42
        "UNTIL":          UNTIL,
43
        "TABLE":          TABLE,
44
        "PRIMARY":        PRIMARY,
45
        "KEY":            KEY,
46
        "UNIQUE":         UNIQUE,
47
        "INDEX":          INDEX,
48
        "ON":             ON,
49
        "ALTER":          ALTER,
50
        "ADD":            ADD,
51
        "RENAME":         RENAME,
52
        "TO":             TO,
53
        "COLUMN":         COLUMN,
54
        "INSERT":         INSERT,
55
        "CONFLICT":       CONFLICT,
56
        "DO":             DO,
57
        "NOTHING":        NOTHING,
58
        "UPSERT":         UPSERT,
59
        "INTO":           INTO,
60
        "VALUES":         VALUES,
61
        "UPDATE":         UPDATE,
62
        "SET":            SET,
63
        "DELETE":         DELETE,
64
        "BEGIN":          BEGIN,
65
        "TRANSACTION":    TRANSACTION,
66
        "COMMIT":         COMMIT,
67
        "ROLLBACK":       ROLLBACK,
68
        "SELECT":         SELECT,
69
        "DISTINCT":       DISTINCT,
70
        "FROM":           FROM,
71
        "UNION":          UNION,
72
        "ALL":            ALL,
73
        "TX":             TX,
74
        "JOIN":           JOIN,
75
        "HAVING":         HAVING,
76
        "WHERE":          WHERE,
77
        "GROUP":          GROUP,
78
        "BY":             BY,
79
        "LIMIT":          LIMIT,
80
        "OFFSET":         OFFSET,
81
        "ORDER":          ORDER,
82
        "AS":             AS,
83
        "ASC":            ASC,
84
        "DESC":           DESC,
85
        "NOT":            NOT,
86
        "LIKE":           LIKE,
87
        "EXISTS":         EXISTS,
88
        "IN":             IN,
89
        "AUTO_INCREMENT": AUTO_INCREMENT,
90
        "NULL":           NULL,
91
        "IF":             IF,
92
        "IS":             IS,
93
        "CAST":           CAST,
94
        "::":             SCAST,
95
        "SHOW":           SHOW,
96
        "DATABASES":      DATABASES,
97
        "TABLES":         TABLES,
98
        "USERS":          USERS,
99
        "USER":           USER,
100
        "WITH":           WITH,
101
        "PASSWORD":       PASSWORD,
102
        "READ":           READ,
103
        "READWRITE":      READWRITE,
104
        "ADMIN":          ADMIN,
105
        "CHECK":          CHECK,
106
        "CONSTRAINT":     CONSTRAINT,
107
}
108

109
var joinTypes = map[string]JoinType{
110
        "INNER": InnerJoin,
111
        "LEFT":  LeftJoin,
112
        "RIGHT": RightJoin,
113
}
114

115
var types = map[string]SQLValueType{
116
        "INTEGER":   IntegerType,
117
        "BOOLEAN":   BooleanType,
118
        "VARCHAR":   VarcharType,
119
        "UUID":      UUIDType,
120
        "BLOB":      BLOBType,
121
        "TIMESTAMP": TimestampType,
122
        "FLOAT":     Float64Type,
123
        "JSON":      JSONType,
124
}
125

126
var aggregateFns = map[string]AggregateFn{
127
        "COUNT": COUNT,
128
        "SUM":   SUM,
129
        "MAX":   MAX,
130
        "MIN":   MIN,
131
        "AVG":   AVG,
132
}
133

134
var boolValues = map[string]bool{
135
        "TRUE":  true,
136
        "FALSE": false,
137
}
138

139
var cmpOps = map[string]CmpOperator{
140
        "=":  EQ,
141
        "!=": NE,
142
        "<>": NE,
143
        "<":  LT,
144
        "<=": LE,
145
        ">":  GT,
146
        ">=": GE,
147
}
148

149
var logicOps = map[string]LogicOperator{
150
        "AND": AND,
151
        "OR":  OR,
152
}
153

154
// Parameter errors set by Lex while tokenizing query parameters.
var (
	// ErrEitherNamedOrUnnamedParams is reported when unnamed (?) parameters
	// are mixed with named (@name or $n) parameters in the same input.
	ErrEitherNamedOrUnnamedParams = errors.New("either named or unnamed params")

	// ErrEitherPosOrNonPosParams is reported when positional ($n) and
	// non-positional (@name) named parameters are mixed in the same input.
	ErrEitherPosOrNonPosParams = errors.New("either positional or non-positional named params")

	// ErrInvalidPositionalParameter is reported for a $n parameter with n < 1.
	ErrInvalidPositionalParameter = errors.New("invalid positional parameter")
)
157

158
// positionalParamType records which query-parameter syntax the lexer has seen
// so far. The zero value (no constant below) means no parameter has been lexed
// yet; Lex rejects input that mixes two different styles.
type positionalParamType int

const (
	// NamedNonPositionalParamType marks parameters written as @name.
	NamedNonPositionalParamType positionalParamType = iota + 1
	// NamedPositionalParamType marks parameters written as $1, $2, ...
	NamedPositionalParamType
	// UnnamedParamType marks parameters written as ?; they are numbered in
	// order of appearance.
	UnnamedParamType
)
165

166
// lexer tokenizes SQL text for the goyacc-generated parser (its Lex and Error
// methods are the parser's hooks) and holds the outcome of a ParseSQL call.
type lexer struct {
	r               *aheadByteReader    // input with one byte of look-ahead
	err             error               // set by Error; returned by ParseSQL
	namedParamsType positionalParamType // parameter style seen so far; mixing styles is an error
	paramsCount     int                 // number of '?' parameters lexed; each gets the next count as its index
	result          []SQLStmt           // returned by ParseSQL; presumably filled by grammar actions — confirm in sql_grammar.y
}
173

174
// aheadByteReader wraps an io.ByteReader with one byte of look-ahead: the
// byte that the next ReadByte call will return (and its error) is always
// available through NextByte without consuming it.
type aheadByteReader struct {
	nextChar  byte          // byte the next ReadByte will return
	nextErr   error         // error the next ReadByte will return
	r         io.ByteReader // underlying source
	readCount int           // number of ReadByte calls made so far
}

// newAheadByteReader returns a reader over r with the first byte of r
// already buffered as the look-ahead.
func newAheadByteReader(r io.ByteReader) *aheadByteReader {
	first, err := r.ReadByte()
	return &aheadByteReader{
		r:        r,
		nextChar: first,
		nextErr:  err,
	}
}

// ReadByte returns the buffered look-ahead byte and error, then refills the
// look-ahead from the underlying reader. Once an error has been buffered it
// is sticky: the underlying reader is not touched again and every later call
// returns that same error. Every call counts towards ReadCount.
func (ar *aheadByteReader) ReadByte() (byte, error) {
	ar.readCount++

	current, currentErr := ar.nextChar, ar.nextErr
	if currentErr == nil {
		ar.nextChar, ar.nextErr = ar.r.ReadByte()
	}

	return current, currentErr
}

// ReadCount reports how many times ReadByte has been called.
func (ar *aheadByteReader) ReadCount() int {
	count := ar.readCount
	return count
}

// NextByte peeks at the byte (and error) the next ReadByte will return,
// without consuming it.
func (ar *aheadByteReader) NextByte() (byte, error) {
	peeked, peekErr := ar.nextChar, ar.nextErr
	return peeked, peekErr
}
206

207
func ParseSQLString(sql string) ([]SQLStmt, error) {
208✔
208
        return ParseSQL(strings.NewReader(sql))
208✔
209
}
208✔
210

211
func ParseSQL(r io.ByteReader) ([]SQLStmt, error) {
3,048✔
212
        lexer := newLexer(r)
3,048✔
213

3,048✔
214
        yyParse(lexer)
3,048✔
215

3,048✔
216
        return lexer.result, lexer.err
3,048✔
217
}
3,048✔
218

219
func ParseExpFromString(exp string) (ValueExp, error) {
108✔
220
        stmt := fmt.Sprintf("SELECT * FROM t WHERE %s", exp)
108✔
221

108✔
222
        res, err := ParseSQLString(stmt)
108✔
223
        if err != nil {
108✔
NEW
224
                return nil, err
×
NEW
225
        }
×
226

227
        s := res[0].(*SelectStmt)
108✔
228
        return s.where, nil
108✔
229
}
230

231
func newLexer(r io.ByteReader) *lexer {
3,048✔
232
        return &lexer{
3,048✔
233
                r:   newAheadByteReader(r),
3,048✔
234
                err: nil,
3,048✔
235
        }
3,048✔
236
}
3,048✔
237

238
func (l *lexer) Lex(lval *yySymType) int {
52,187✔
239
        var ch byte
52,187✔
240
        var err error
52,187✔
241

52,187✔
242
        for {
133,088✔
243
                ch, err = l.r.ReadByte()
80,901✔
244
                if err == io.EOF {
83,812✔
245
                        return 0
2,911✔
246
                }
2,911✔
247
                if err != nil {
77,990✔
248
                        lval.err = err
×
249
                        return ERROR
×
250
                }
×
251

252
                if ch == '\t' {
80,991✔
253
                        continue
3,001✔
254
                }
255

256
                if ch == '/' && l.r.nextChar == '*' {
74,992✔
257
                        l.r.ReadByte()
3✔
258

3✔
259
                        for {
102✔
260
                                ch, err := l.r.ReadByte()
99✔
261
                                if err == io.EOF {
99✔
262
                                        break
×
263
                                }
264
                                if err != nil {
99✔
265
                                        lval.err = err
×
266
                                        return ERROR
×
267
                                }
×
268

269
                                if ch == '*' && l.r.nextChar == '/' {
102✔
270
                                        l.r.ReadByte() // consume closing slash
3✔
271
                                        break
3✔
272
                                }
273
                        }
274

275
                        continue
3✔
276
                }
277

278
                if isLineBreak(ch) {
75,902✔
279
                        if ch == '\r' && l.r.nextChar == '\n' {
917✔
280
                                l.r.ReadByte()
1✔
281
                        }
1✔
282
                        continue
916✔
283
                }
284

285
                if !isSpace(ch) {
123,346✔
286
                        break
49,276✔
287
                }
288
        }
289

290
        if isSeparator(ch) {
49,671✔
291
                return STMT_SEPARATOR
395✔
292
        }
395✔
293

294
        if ch == '-' && l.r.nextChar == '>' {
48,931✔
295
                l.r.ReadByte()
50✔
296
                return ARROW
50✔
297
        }
50✔
298

299
        if isBLOBPrefix(ch) && isQuote(l.r.nextChar) {
48,925✔
300
                l.r.ReadByte() // consume starting quote
94✔
301

94✔
302
                tail, err := l.readString()
94✔
303
                if err != nil {
94✔
304
                        lval.err = err
×
305
                        return ERROR
×
306
                }
×
307

308
                val, err := hex.DecodeString(tail)
94✔
309
                if err != nil {
94✔
310
                        lval.err = err
×
311
                        return ERROR
×
312
                }
×
313

314
                lval.blob = val
94✔
315
                return BLOB
94✔
316
        }
317

318
        if isLetter(ch) {
71,650✔
319
                tail, err := l.readWord()
22,913✔
320
                if err != nil {
22,913✔
321
                        lval.err = err
×
322
                        return ERROR
×
323
                }
×
324

325
                w := fmt.Sprintf("%c%s", ch, tail)
22,913✔
326
                tid := strings.ToUpper(w)
22,913✔
327

22,913✔
328
                sqlType, ok := types[tid]
22,913✔
329
                if ok {
23,544✔
330
                        lval.sqlType = sqlType
631✔
331
                        return TYPE
631✔
332
                }
631✔
333

334
                val, ok := boolValues[tid]
22,282✔
335
                if ok {
22,472✔
336
                        lval.boolean = val
190✔
337
                        return BOOLEAN
190✔
338
                }
190✔
339

340
                lop, ok := logicOps[tid]
22,092✔
341
                if ok {
22,240✔
342
                        lval.logicOp = lop
148✔
343
                        return LOP
148✔
344
                }
148✔
345

346
                afn, ok := aggregateFns[tid]
21,944✔
347
                if ok {
22,077✔
348
                        lval.aggFn = afn
133✔
349
                        return AGGREGATE_FUNC
133✔
350
                }
133✔
351

352
                join, ok := joinTypes[tid]
21,811✔
353
                if ok {
21,828✔
354
                        lval.joinType = join
17✔
355
                        return JOINTYPE
17✔
356
                }
17✔
357

358
                tkn, ok := reservedWords[tid]
21,794✔
359
                if ok {
32,096✔
360
                        return tkn
10,302✔
361
                }
10,302✔
362

363
                lval.id = strings.ToLower(w)
11,492✔
364

11,492✔
365
                return IDENTIFIER
11,492✔
366
        }
367

368
        if isDoubleQuote(ch) {
25,831✔
369
                tail, err := l.readWord()
7✔
370
                if err != nil {
7✔
371
                        lval.err = err
×
372
                        return ERROR
×
373
                }
×
374

375
                if !isDoubleQuote(l.r.nextChar) {
8✔
376
                        lval.err = fmt.Errorf("double quote expected")
1✔
377
                        return ERROR
1✔
378
                }
1✔
379

380
                l.r.ReadByte() // consume ending quote
6✔
381

6✔
382
                lval.id = strings.ToLower(tail)
6✔
383
                return IDENTIFIER
6✔
384
        }
385

386
        if isNumber(ch) {
26,688✔
387
                tail, err := l.readNumber()
871✔
388
                if err != nil {
871✔
389
                        lval.err = err
×
390
                        return ERROR
×
391
                }
×
392
                // looking for a float
393
                if isDot(l.r.nextChar) {
915✔
394
                        l.r.ReadByte() // consume dot
44✔
395

44✔
396
                        decimalPart, err := l.readNumber()
44✔
397
                        if err != nil {
44✔
398
                                lval.err = err
×
399
                                return ERROR
×
400
                        }
×
401

402
                        val, err := strconv.ParseFloat(fmt.Sprintf("%c%s.%s", ch, tail, decimalPart), 64)
44✔
403
                        if err != nil {
45✔
404
                                lval.err = err
1✔
405
                                return ERROR
1✔
406
                        }
1✔
407

408
                        lval.float = val
43✔
409
                        return FLOAT
43✔
410
                }
411

412
                val, err := strconv.ParseUint(fmt.Sprintf("%c%s", ch, tail), 10, 64)
827✔
413
                if err != nil {
827✔
414
                        lval.err = err
×
415
                        return ERROR
×
416
                }
×
417

418
                lval.integer = val
827✔
419
                return INTEGER
827✔
420
        }
421

422
        if isComparison(ch) {
25,389✔
423
                tail, err := l.readComparison()
443✔
424
                if err != nil {
443✔
425
                        lval.err = err
×
426
                        return ERROR
×
427
                }
×
428

429
                op := fmt.Sprintf("%c%s", ch, tail)
443✔
430

443✔
431
                cmpOp, ok := cmpOps[op]
443✔
432
                if !ok {
443✔
433
                        lval.err = fmt.Errorf("invalid comparison operator %s", op)
×
434
                        return ERROR
×
435
                }
×
436

437
                lval.cmpOp = cmpOp
443✔
438
                return CMPOP
443✔
439
        }
440

441
        if isQuote(ch) {
25,078✔
442
                tail, err := l.readString()
575✔
443
                if err != nil {
575✔
444
                        lval.err = err
×
445
                        return ERROR
×
446
                }
×
447

448
                lval.str = tail
575✔
449
                return VARCHAR
575✔
450
        }
451

452
        if ch == ':' {
23,930✔
453
                ch, err := l.r.ReadByte()
2✔
454
                if err != nil {
2✔
455
                        lval.err = err
×
456
                        return ERROR
×
457
                }
×
458

459
                if ch != ':' {
2✔
460
                        lval.err = fmt.Errorf("colon expected")
×
461
                        return ERROR
×
462
                }
×
463

464
                return SCAST
2✔
465
        }
466

467
        if ch == '@' {
28,533✔
468
                if l.namedParamsType == UnnamedParamType {
4,608✔
469
                        lval.err = ErrEitherNamedOrUnnamedParams
1✔
470
                        return ERROR
1✔
471
                }
1✔
472

473
                if l.namedParamsType == NamedPositionalParamType {
4,607✔
474
                        lval.err = ErrEitherPosOrNonPosParams
1✔
475
                        return ERROR
1✔
476
                }
1✔
477

478
                l.namedParamsType = NamedNonPositionalParamType
4,605✔
479

4,605✔
480
                ch, err := l.r.NextByte()
4,605✔
481
                if err != nil {
4,605✔
482
                        lval.err = err
×
483
                        return ERROR
×
484
                }
×
485

486
                if !isLetter(ch) {
4,605✔
487
                        return ERROR
×
488
                }
×
489

490
                id, err := l.readWord()
4,605✔
491
                if err != nil {
4,605✔
492
                        lval.err = err
×
493
                        return ERROR
×
494
                }
×
495

496
                lval.id = strings.ToLower(id)
4,605✔
497

4,605✔
498
                return NPARAM
4,605✔
499
        }
500

501
        if ch == '$' {
19,361✔
502
                if l.namedParamsType == UnnamedParamType {
43✔
503
                        lval.err = ErrEitherNamedOrUnnamedParams
1✔
504
                        return ERROR
1✔
505
                }
1✔
506

507
                if l.namedParamsType == NamedNonPositionalParamType {
42✔
508
                        lval.err = ErrEitherPosOrNonPosParams
1✔
509
                        return ERROR
1✔
510
                }
1✔
511

512
                id, err := l.readNumber()
40✔
513
                if err != nil {
40✔
514
                        lval.err = err
×
515
                        return ERROR
×
516
                }
×
517

518
                pid, err := strconv.Atoi(id)
40✔
519
                if err != nil {
41✔
520
                        lval.err = err
1✔
521
                        return ERROR
1✔
522
                }
1✔
523

524
                if pid < 1 {
40✔
525
                        lval.err = ErrInvalidPositionalParameter
1✔
526
                        return ERROR
1✔
527
                }
1✔
528

529
                lval.pparam = pid
38✔
530

38✔
531
                l.namedParamsType = NamedPositionalParamType
38✔
532

38✔
533
                return PPARAM
38✔
534
        }
535

536
        if ch == '?' {
19,416✔
537
                if l.namedParamsType == NamedNonPositionalParamType || l.namedParamsType == NamedPositionalParamType {
141✔
538
                        lval.err = ErrEitherNamedOrUnnamedParams
2✔
539
                        return ERROR
2✔
540
                }
2✔
541

542
                l.paramsCount++
137✔
543
                lval.pparam = l.paramsCount
137✔
544

137✔
545
                l.namedParamsType = UnnamedParamType
137✔
546

137✔
547
                return PPARAM
137✔
548
        }
549

550
        if isDot(ch) {
19,243✔
551
                if isNumber(l.r.nextChar) { // looking for  a float
110✔
552
                        decimalPart, err := l.readNumber()
5✔
553
                        if err != nil {
5✔
554
                                lval.err = err
×
555
                                return ERROR
×
556
                        }
×
557
                        val, err := strconv.ParseFloat(fmt.Sprintf("%d.%s", 0, decimalPart), 64)
5✔
558
                        if err != nil {
5✔
559
                                lval.err = err
×
560
                                return ERROR
×
561
                        }
×
562
                        lval.float = val
5✔
563
                        return FLOAT
5✔
564
                }
565
                return DOT
100✔
566
        }
567

568
        return int(ch)
19,033✔
569
}
570

571
func (l *lexer) Error(err string) {
41✔
572
        l.err = fmt.Errorf("%s at position %d", err, l.r.ReadCount())
41✔
573
}
41✔
574

575
func (l *lexer) readWord() (string, error) {
27,525✔
576
        return l.readWhile(func(ch byte) bool {
173,224✔
577
                return isLetter(ch) || isNumber(ch)
145,699✔
578
        })
145,699✔
579
}
580

581
func (l *lexer) readNumber() (string, error) {
960✔
582
        return l.readWhile(isNumber)
960✔
583
}
960✔
584

585
func (l *lexer) readString() (string, error) {
669✔
586
        var b bytes.Buffer
669✔
587

669✔
588
        for {
15,786✔
589
                ch, err := l.r.ReadByte()
15,117✔
590
                if err != nil {
15,117✔
591
                        return "", err
×
592
                }
×
593

594
                nextCh, _ := l.r.NextByte()
15,117✔
595

15,117✔
596
                if isQuote(ch) {
15,788✔
597
                        if isQuote(nextCh) {
673✔
598
                                l.r.ReadByte() // consume escaped quote
2✔
599
                        } else {
671✔
600
                                break // string completely read
669✔
601
                        }
602
                }
603

604
                b.WriteByte(ch)
14,448✔
605
        }
606

607
        return b.String(), nil
669✔
608
}
609

610
func (l *lexer) readComparison() (string, error) {
443✔
611
        return l.readWhile(func(ch byte) bool {
1,018✔
612
                return isComparison(ch)
575✔
613
        })
575✔
614
}
615

616
func (l *lexer) readWhile(condFn func(b byte) bool) (string, error) {
28,928✔
617
        var b bytes.Buffer
28,928✔
618

28,928✔
619
        for {
187,038✔
620
                ch, err := l.r.NextByte()
158,110✔
621
                if err == io.EOF {
158,590✔
622
                        break
480✔
623
                }
624
                if err != nil {
157,630✔
625
                        return "", err
×
626
                }
×
627

628
                if !condFn(ch) {
186,078✔
629
                        break
28,448✔
630
                }
631

632
                ch, _ = l.r.ReadByte()
129,182✔
633
                b.WriteByte(ch)
129,182✔
634
        }
635

636
        return b.String(), nil
28,928✔
637
}
638

639
// isBLOBPrefix reports whether ch is the lowercase 'x' that, when directly
// followed by a single quote, starts a hex BLOB literal such as x'0aff'.
func isBLOBPrefix(ch byte) bool {
	const blobPrefix byte = 'x'
	return blobPrefix == ch
}
642

643
// isSeparator reports whether ch is the ';' statement separator.
func isSeparator(ch byte) bool {
	switch ch {
	case ';':
		return true
	default:
		return false
	}
}
646

647
// isLineBreak reports whether ch is a carriage return or a line feed.
func isLineBreak(ch byte) bool {
	switch ch {
	case '\n', '\r':
		return true
	}
	return false
}
650

651
// isSpace reports whether ch is a space or a horizontal tab.
func isSpace(ch byte) bool {
	return ch == ' ' || ch == '\t'
}
654

655
// isNumber reports whether ch is an ASCII decimal digit.
func isNumber(ch byte) bool {
	return ch >= '0' && ch <= '9'
}
658

659
// isLetter reports whether ch can start a word: an ASCII letter of either
// case or an underscore.
func isLetter(ch byte) bool {
	switch {
	case ch >= 'a' && ch <= 'z', ch >= 'A' && ch <= 'Z':
		return true
	default:
		return ch == '_'
	}
}
662

663
// isComparison reports whether ch can appear in a comparison operator
// ('!', '<', '=' or '>').
func isComparison(ch byte) bool {
	switch ch {
	case '<', '=', '>', '!':
		return true
	}
	return false
}
666

667
// isQuote reports whether ch is a single quote (0x27), the string delimiter.
func isQuote(ch byte) bool {
	return ch == '\''
}
670

671
// isDoubleQuote reports whether ch is a double quote (0x22), the quoted
// identifier delimiter.
func isDoubleQuote(ch byte) bool {
	return ch == '"'
}
674

675
// isDot reports whether ch is '.', used both as the DOT token and as the
// decimal point of float literals.
func isDot(ch byte) bool {
	const dot byte = '.'
	return ch == dot
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc