• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 12258730221

10 Dec 2024 02:47PM UTC coverage: 89.138% (-0.1%) from 89.266%
12258730221

Pull #2036

gh-ci

ostafen
chore(embedded/sql): Add support for core pg_catalog tables (pg_class, pg_namespace, pg_roles)

Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>
Pull Request #2036: chore(embedded/sql): Add support for core pg_catalog tables (pg_class…

101 of 183 new or added lines in 13 files covered. (55.19%)

1 existing line in 1 file now uncovered.

37586 of 42166 relevant lines covered (89.14%)

150871.1 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.87
/embedded/sql/parser.go
1
/*
2
Copyright 2024 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "encoding/hex"
22
        "errors"
23
        "fmt"
24
        "io"
25
        "strconv"
26
        "strings"
27
)
28

29
//go:generate go run golang.org/x/tools/cmd/goyacc -l -o sql_parser.go sql_grammar.y
30

31
var reservedWords = map[string]int{
32
        "CREATE":         CREATE,
33
        "DROP":           DROP,
34
        "USE":            USE,
35
        "DATABASE":       DATABASE,
36
        "SNAPSHOT":       SNAPSHOT,
37
        "HISTORY":        HISTORY,
38
        "OF":             OF,
39
        "SINCE":          SINCE,
40
        "AFTER":          AFTER,
41
        "BEFORE":         BEFORE,
42
        "UNTIL":          UNTIL,
43
        "TABLE":          TABLE,
44
        "PRIMARY":        PRIMARY,
45
        "KEY":            KEY,
46
        "UNIQUE":         UNIQUE,
47
        "INDEX":          INDEX,
48
        "ON":             ON,
49
        "ALTER":          ALTER,
50
        "ADD":            ADD,
51
        "RENAME":         RENAME,
52
        "TO":             TO,
53
        "COLUMN":         COLUMN,
54
        "INSERT":         INSERT,
55
        "CONFLICT":       CONFLICT,
56
        "DO":             DO,
57
        "NOTHING":        NOTHING,
58
        "RETURNING":      RETURNING,
59
        "UPSERT":         UPSERT,
60
        "INTO":           INTO,
61
        "VALUES":         VALUES,
62
        "UPDATE":         UPDATE,
63
        "SET":            SET,
64
        "DELETE":         DELETE,
65
        "BEGIN":          BEGIN,
66
        "TRANSACTION":    TRANSACTION,
67
        "COMMIT":         COMMIT,
68
        "ROLLBACK":       ROLLBACK,
69
        "SELECT":         SELECT,
70
        "DISTINCT":       DISTINCT,
71
        "FROM":           FROM,
72
        "UNION":          UNION,
73
        "ALL":            ALL,
74
        "TX":             TX,
75
        "JOIN":           JOIN,
76
        "HAVING":         HAVING,
77
        "WHERE":          WHERE,
78
        "GROUP":          GROUP,
79
        "BY":             BY,
80
        "LIMIT":          LIMIT,
81
        "OFFSET":         OFFSET,
82
        "ORDER":          ORDER,
83
        "AS":             AS,
84
        "ASC":            ASC,
85
        "DESC":           DESC,
86
        "AND":            AND,
87
        "OR":             OR,
88
        "NOT":            NOT,
89
        "LIKE":           LIKE,
90
        "EXISTS":         EXISTS,
91
        "IN":             IN,
92
        "AUTO_INCREMENT": AUTO_INCREMENT,
93
        "NULL":           NULL,
94
        "IF":             IF,
95
        "IS":             IS,
96
        "CAST":           CAST,
97
        "::":             SCAST,
98
        "SHOW":           SHOW,
99
        "DATABASES":      DATABASES,
100
        "TABLES":         TABLES,
101
        "USERS":          USERS,
102
        "USER":           USER,
103
        "WITH":           WITH,
104
        "PASSWORD":       PASSWORD,
105
        "READ":           READ,
106
        "READWRITE":      READWRITE,
107
        "ADMIN":          ADMIN,
108
        "GRANT":          GRANT,
109
        "REVOKE":         REVOKE,
110
        "GRANTS":         GRANTS,
111
        "FOR":            FOR,
112
        "PRIVILEGES":     PRIVILEGES,
113
        "CHECK":          CHECK,
114
        "CONSTRAINT":     CONSTRAINT,
115
        "CASE":           CASE,
116
        "WHEN":           WHEN,
117
        "THEN":           THEN,
118
        "ELSE":           ELSE,
119
        "END":            END,
120
}
121

122
var joinTypes = map[string]JoinType{
123
        "INNER": InnerJoin,
124
        "LEFT":  LeftJoin,
125
        "RIGHT": RightJoin,
126
}
127

128
var types = map[string]SQLValueType{
129
        "INTEGER":   IntegerType,
130
        "BOOLEAN":   BooleanType,
131
        "VARCHAR":   VarcharType,
132
        "UUID":      UUIDType,
133
        "BLOB":      BLOBType,
134
        "TIMESTAMP": TimestampType,
135
        "FLOAT":     Float64Type,
136
        "JSON":      JSONType,
137
}
138

139
var aggregateFns = map[string]AggregateFn{
140
        "COUNT": COUNT,
141
        "SUM":   SUM,
142
        "MAX":   MAX,
143
        "MIN":   MIN,
144
        "AVG":   AVG,
145
}
146

147
// boolValues maps the upper-cased boolean literals to their Go values; Lex
// returns a BOOLEAN token carrying the mapped value for these words.
var boolValues = map[string]bool{
	"FALSE": false,
	"TRUE":  true,
}
151

152
var cmpOps = map[string]CmpOperator{
153
        "=":  EQ,
154
        "!=": NE,
155
        "<>": NE,
156
        "<":  LT,
157
        "<=": LE,
158
        ">":  GT,
159
        ">=": GE,
160
}
161

162
// Errors reported by the lexer when parameter placeholders are misused.
var (
	// ErrEitherNamedOrUnnamedParams is set when named ('@name' or '$n') and
	// unnamed ('?') parameters are mixed in the same input.
	ErrEitherNamedOrUnnamedParams = errors.New("either named or unnamed params")

	// ErrEitherPosOrNonPosParams is set when '@name' and '$n' parameters are
	// mixed in the same input.
	ErrEitherPosOrNonPosParams = errors.New("either positional or non-positional named params")

	// ErrInvalidPositionalParameter is set when a '$n' parameter has n < 1.
	ErrInvalidPositionalParameter = errors.New("invalid positional parameter")
)
165

166
// positionalParamType records which parameter placeholder style the lexer
// has encountered. The zero value means no placeholder has been seen yet.
type positionalParamType int

const (
	_ positionalParamType = iota // zero value: no placeholder style seen yet

	// NamedNonPositionalParamType is recorded for '@name' parameters.
	NamedNonPositionalParamType
	// NamedPositionalParamType is recorded for '$n' parameters.
	NamedPositionalParamType
	// UnnamedParamType is recorded for '?' parameters.
	UnnamedParamType
)
173

174
// lexer tokenizes SQL input for the generated parser (see Lex and Error).
type lexer struct {
	r               *aheadByteReader    // input with one byte of lookahead
	err             error               // last error reported through Error
	namedParamsType positionalParamType // placeholder style seen so far; zero until the first one
	paramsCount     int                 // number of '?' placeholders seen so far
	result          []SQLStmt           // returned by ParseSQL after yyParse has run
}
181

182
// aheadByteReader wraps an io.ByteReader with one byte of lookahead: the
// byte and error that the next ReadByte call will return are always fetched
// in advance and can be inspected without being consumed.
type aheadByteReader struct {
	nextChar  byte          // byte the next ReadByte call will return
	nextErr   error         // error the next ReadByte call will return
	r         io.ByteReader // underlying source
	readCount int           // number of ReadByte calls made so far
}

// newAheadByteReader returns a reader over r, pre-fetching its first byte.
func newAheadByteReader(r io.ByteReader) *aheadByteReader {
	first, err := r.ReadByte()

	return &aheadByteReader{
		nextChar: first,
		nextErr:  err,
		r:        r,
	}
}

// ReadByte returns the pre-fetched byte and error, then fetches the next
// byte from the underlying reader. Once the underlying reader has returned
// an error, that error is returned by every subsequent call and the
// underlying reader is not read again. Every call, including failing ones,
// increments the read count.
func (ar *aheadByteReader) ReadByte() (byte, error) {
	ch, err := ar.nextChar, ar.nextErr

	ar.readCount++

	if err == nil {
		ar.nextChar, ar.nextErr = ar.r.ReadByte()
	}

	return ch, err
}

// ReadCount returns how many times ReadByte has been called.
func (ar *aheadByteReader) ReadCount() int {
	return ar.readCount
}

// NextByte returns what the next ReadByte call will return, without
// consuming it.
func (ar *aheadByteReader) NextByte() (byte, error) {
	return ar.nextChar, ar.nextErr
}
222,254✔
214

215
func ParseSQLString(sql string) ([]SQLStmt, error) {
311✔
216
        return ParseSQL(strings.NewReader(sql))
311✔
217
}
311✔
218

219
func ParseSQL(r io.ByteReader) ([]SQLStmt, error) {
3,444✔
220
        lexer := newLexer(r)
3,444✔
221

3,444✔
222
        yyParse(lexer)
3,444✔
223

3,444✔
224
        return lexer.result, lexer.err
3,444✔
225
}
3,444✔
226

227
func ParseExpFromString(exp string) (ValueExp, error) {
199✔
228
        stmt := fmt.Sprintf("SELECT * FROM t WHERE %s", exp)
199✔
229

199✔
230
        res, err := ParseSQLString(stmt)
199✔
231
        if err != nil {
199✔
232
                return nil, err
×
233
        }
×
234

235
        s := res[0].(*SelectStmt)
199✔
236
        return s.where, nil
199✔
237
}
238

239
func newLexer(r io.ByteReader) *lexer {
3,444✔
240
        return &lexer{
3,444✔
241
                r:   newAheadByteReader(r),
3,444✔
242
                err: nil,
3,444✔
243
        }
3,444✔
244
}
3,444✔
245

246
func (l *lexer) Lex(lval *yySymType) int {
62,148✔
247
        var ch byte
62,148✔
248
        var err error
62,148✔
249

62,148✔
250
        for {
159,558✔
251
                ch, err = l.r.ReadByte()
97,410✔
252
                if err == io.EOF {
100,718✔
253
                        return 0
3,308✔
254
                }
3,308✔
255
                if err != nil {
94,102✔
256
                        lval.err = err
×
257
                        return ERROR
×
258
                }
×
259

260
                if ch == '\t' {
98,517✔
261
                        continue
4,415✔
262
                }
263

264
                if ch == '/' && l.r.nextChar == '*' {
89,690✔
265
                        l.r.ReadByte()
3✔
266

3✔
267
                        for {
102✔
268
                                ch, err := l.r.ReadByte()
99✔
269
                                if err == io.EOF {
99✔
270
                                        break
×
271
                                }
272
                                if err != nil {
99✔
273
                                        lval.err = err
×
274
                                        return ERROR
×
275
                                }
×
276

277
                                if ch == '*' && l.r.nextChar == '/' {
102✔
278
                                        l.r.ReadByte() // consume closing slash
3✔
279
                                        break
3✔
280
                                }
281
                        }
282

283
                        continue
3✔
284
                }
285

286
                if isLineBreak(ch) {
90,990✔
287
                        if ch == '\r' && l.r.nextChar == '\n' {
1,307✔
288
                                l.r.ReadByte()
1✔
289
                        }
1✔
290
                        continue
1,306✔
291
                }
292

293
                if !isSpace(ch) {
147,218✔
294
                        break
58,840✔
295
                }
296
        }
297

298
        if isSeparator(ch) {
59,255✔
299
                return STMT_SEPARATOR
415✔
300
        }
415✔
301

302
        if ch == '-' && l.r.nextChar == '>' {
58,483✔
303
                l.r.ReadByte()
58✔
304
                return ARROW
58✔
305
        }
58✔
306

307
        if isBLOBPrefix(ch) && isQuote(l.r.nextChar) {
58,461✔
308
                l.r.ReadByte() // consume starting quote
94✔
309

94✔
310
                tail, err := l.readString()
94✔
311
                if err != nil {
94✔
312
                        lval.err = err
×
313
                        return ERROR
×
314
                }
×
315

316
                val, err := hex.DecodeString(tail)
94✔
317
                if err != nil {
94✔
318
                        lval.err = err
×
319
                        return ERROR
×
320
                }
×
321

322
                lval.blob = val
94✔
323
                return BLOB
94✔
324
        }
325

326
        if isLetter(ch) {
84,949✔
327
                tail, err := l.readWord()
26,676✔
328
                if err != nil {
26,676✔
329
                        lval.err = err
×
330
                        return ERROR
×
331
                }
×
332

333
                w := fmt.Sprintf("%c%s", ch, tail)
26,676✔
334
                tid := strings.ToUpper(w)
26,676✔
335

26,676✔
336
                sqlType, ok := types[tid]
26,676✔
337
                if ok {
27,359✔
338
                        lval.sqlType = sqlType
683✔
339
                        return TYPE
683✔
340
                }
683✔
341

342
                val, ok := boolValues[tid]
25,993✔
343
                if ok {
26,232✔
344
                        lval.boolean = val
239✔
345
                        return BOOLEAN
239✔
346
                }
239✔
347

348
                afn, ok := aggregateFns[tid]
25,754✔
349
                if ok {
25,887✔
350
                        lval.aggFn = afn
133✔
351
                        return AGGREGATE_FUNC
133✔
352
                }
133✔
353

354
                join, ok := joinTypes[tid]
25,621✔
355
                if ok {
25,642✔
356
                        lval.joinType = join
21✔
357
                        return JOINTYPE
21✔
358
                }
21✔
359

360
                tkn, ok := reservedWords[tid]
25,600✔
361
                if ok {
37,520✔
362
                        return tkn
11,920✔
363
                }
11,920✔
364

365
                lval.id = strings.ToLower(w)
13,680✔
366

13,680✔
367
                return IDENTIFIER
13,680✔
368
        }
369

370
        if isDoubleQuote(ch) {
31,604✔
371
                tail, err := l.readWord()
7✔
372
                if err != nil {
7✔
373
                        lval.err = err
×
374
                        return ERROR
×
375
                }
×
376

377
                if !isDoubleQuote(l.r.nextChar) {
8✔
378
                        lval.err = fmt.Errorf("double quote expected")
1✔
379
                        return ERROR
1✔
380
                }
1✔
381

382
                l.r.ReadByte() // consume ending quote
6✔
383

6✔
384
                lval.id = strings.ToLower(tail)
6✔
385
                return IDENTIFIER
6✔
386
        }
387

388
        if isNumber(ch) {
32,626✔
389
                tail, err := l.readNumber()
1,036✔
390
                if err != nil {
1,036✔
391
                        lval.err = err
×
392
                        return ERROR
×
393
                }
×
394
                // looking for a float
395
                if isDot(l.r.nextChar) {
1,099✔
396
                        l.r.ReadByte() // consume dot
63✔
397

63✔
398
                        decimalPart, err := l.readNumber()
63✔
399
                        if err != nil {
63✔
400
                                lval.err = err
×
401
                                return ERROR
×
402
                        }
×
403

404
                        val, err := strconv.ParseFloat(fmt.Sprintf("%c%s.%s", ch, tail, decimalPart), 64)
63✔
405
                        if err != nil {
64✔
406
                                lval.err = err
1✔
407
                                return ERROR
1✔
408
                        }
1✔
409

410
                        lval.float = val
62✔
411
                        return FLOAT
62✔
412
                }
413

414
                val, err := strconv.ParseUint(fmt.Sprintf("%c%s", ch, tail), 10, 64)
973✔
415
                if err != nil {
973✔
416
                        lval.err = err
×
417
                        return ERROR
×
418
                }
×
419

420
                lval.integer = val
973✔
421
                return INTEGER
973✔
422
        }
423

424
        if isComparison(ch) {
31,025✔
425
                tail, err := l.readComparison()
471✔
426
                if err != nil {
471✔
427
                        lval.err = err
×
428
                        return ERROR
×
429
                }
×
430

431
                op := fmt.Sprintf("%c%s", ch, tail)
471✔
432
                if op == "!~" {
471✔
NEW
433
                        return NOT_MATCHES_OP
×
NEW
434
                }
×
435

436
                cmpOp, ok := cmpOps[op]
471✔
437
                if !ok {
471✔
438
                        lval.err = fmt.Errorf("invalid comparison operator %s", op)
×
439
                        return ERROR
×
440
                }
×
441

442
                lval.cmpOp = cmpOp
471✔
443
                return CMPOP
471✔
444
        }
445

446
        if isQuote(ch) {
30,872✔
447
                tail, err := l.readString()
789✔
448
                if err != nil {
789✔
449
                        lval.err = err
×
450
                        return ERROR
×
451
                }
×
452

453
                lval.str = tail
789✔
454
                return VARCHAR
789✔
455
        }
456

457
        if ch == ':' {
29,316✔
458
                ch, err := l.r.ReadByte()
22✔
459
                if err != nil {
22✔
460
                        lval.err = err
×
461
                        return ERROR
×
462
                }
×
463

464
                if ch != ':' {
22✔
465
                        lval.err = fmt.Errorf("colon expected")
×
466
                        return ERROR
×
467
                }
×
468

469
                return SCAST
22✔
470
        }
471

472
        if ch == '@' {
35,212✔
473
                if l.namedParamsType == UnnamedParamType {
5,941✔
474
                        lval.err = ErrEitherNamedOrUnnamedParams
1✔
475
                        return ERROR
1✔
476
                }
1✔
477

478
                if l.namedParamsType == NamedPositionalParamType {
5,940✔
479
                        lval.err = ErrEitherPosOrNonPosParams
1✔
480
                        return ERROR
1✔
481
                }
1✔
482

483
                l.namedParamsType = NamedNonPositionalParamType
5,938✔
484

5,938✔
485
                ch, err := l.r.NextByte()
5,938✔
486
                if err != nil {
5,938✔
487
                        lval.err = err
×
488
                        return ERROR
×
489
                }
×
490

491
                if !isLetter(ch) {
5,938✔
492
                        return ERROR
×
493
                }
×
494

495
                id, err := l.readWord()
5,938✔
496
                if err != nil {
5,938✔
497
                        lval.err = err
×
498
                        return ERROR
×
499
                }
×
500

501
                lval.id = strings.ToLower(id)
5,938✔
502

5,938✔
503
                return NPARAM
5,938✔
504
        }
505

506
        if ch == '$' {
23,374✔
507
                if l.namedParamsType == UnnamedParamType {
43✔
508
                        lval.err = ErrEitherNamedOrUnnamedParams
1✔
509
                        return ERROR
1✔
510
                }
1✔
511

512
                if l.namedParamsType == NamedNonPositionalParamType {
42✔
513
                        lval.err = ErrEitherPosOrNonPosParams
1✔
514
                        return ERROR
1✔
515
                }
1✔
516

517
                id, err := l.readNumber()
40✔
518
                if err != nil {
40✔
519
                        lval.err = err
×
520
                        return ERROR
×
521
                }
×
522

523
                pid, err := strconv.Atoi(id)
40✔
524
                if err != nil {
41✔
525
                        lval.err = err
1✔
526
                        return ERROR
1✔
527
                }
1✔
528

529
                if pid < 1 {
40✔
530
                        lval.err = ErrInvalidPositionalParameter
1✔
531
                        return ERROR
1✔
532
                }
1✔
533

534
                lval.pparam = pid
38✔
535

38✔
536
                l.namedParamsType = NamedPositionalParamType
38✔
537

38✔
538
                return PPARAM
38✔
539
        }
540

541
        if ch == '?' {
23,429✔
542
                if l.namedParamsType == NamedNonPositionalParamType || l.namedParamsType == NamedPositionalParamType {
141✔
543
                        lval.err = ErrEitherNamedOrUnnamedParams
2✔
544
                        return ERROR
2✔
545
                }
2✔
546

547
                l.paramsCount++
137✔
548
                lval.pparam = l.paramsCount
137✔
549

137✔
550
                l.namedParamsType = UnnamedParamType
137✔
551

137✔
552
                return PPARAM
137✔
553
        }
554

555
        if isDot(ch) {
23,304✔
556
                if isNumber(l.r.nextChar) { // looking for  a float
158✔
557
                        decimalPart, err := l.readNumber()
5✔
558
                        if err != nil {
5✔
559
                                lval.err = err
×
560
                                return ERROR
×
561
                        }
×
562
                        val, err := strconv.ParseFloat(fmt.Sprintf("%d.%s", 0, decimalPart), 64)
5✔
563
                        if err != nil {
5✔
564
                                lval.err = err
×
565
                                return ERROR
×
566
                        }
×
567
                        lval.float = val
5✔
568
                        return FLOAT
5✔
569
                }
570
                return DOT
148✔
571
        }
572

573
        return int(ch)
22,998✔
574
}
575

576
func (l *lexer) Error(err string) {
41✔
577
        l.err = fmt.Errorf("%s at position %d", err, l.r.ReadCount())
41✔
578
}
41✔
579

580
func (l *lexer) readWord() (string, error) {
32,621✔
581
        return l.readWhile(func(ch byte) bool {
210,867✔
582
                return isLetter(ch) || isNumber(ch)
178,246✔
583
        })
178,246✔
584
}
585

586
func (l *lexer) readNumber() (string, error) {
1,144✔
587
        return l.readWhile(isNumber)
1,144✔
588
}
1,144✔
589

590
func (l *lexer) readString() (string, error) {
883✔
591
        var b bytes.Buffer
883✔
592

883✔
593
        for {
26,069✔
594
                ch, err := l.r.ReadByte()
25,186✔
595
                if err != nil {
25,186✔
596
                        return "", err
×
597
                }
×
598

599
                nextCh, _ := l.r.NextByte()
25,186✔
600

25,186✔
601
                if isQuote(ch) {
26,071✔
602
                        if isQuote(nextCh) {
887✔
603
                                l.r.ReadByte() // consume escaped quote
2✔
604
                        } else {
885✔
605
                                break // string completely read
883✔
606
                        }
607
                }
608

609
                b.WriteByte(ch)
24,303✔
610
        }
611

612
        return b.String(), nil
883✔
613
}
614

615
func (l *lexer) readComparison() (string, error) {
471✔
616
        return l.readWhile(func(ch byte) bool {
1,080✔
617
                return isComparison(ch)
609✔
618
        })
609✔
619
}
620

621
func (l *lexer) readWhile(condFn func(b byte) bool) (string, error) {
34,236✔
622
        var b bytes.Buffer
34,236✔
623

34,236✔
624
        for {
225,366✔
625
                ch, err := l.r.NextByte()
191,130✔
626
                if err == io.EOF {
191,741✔
627
                        break
611✔
628
                }
629
                if err != nil {
190,519✔
630
                        return "", err
×
631
                }
×
632

633
                if !condFn(ch) {
224,144✔
634
                        break
33,625✔
635
                }
636

637
                ch, _ = l.r.ReadByte()
156,894✔
638
                b.WriteByte(ch)
156,894✔
639
        }
640

641
        return b.String(), nil
34,236✔
642
}
643

644
// isBLOBPrefix reports whether ch is the lower-case 'x' that introduces a
// BLOB literal when followed by a single quote.
func isBLOBPrefix(ch byte) bool {
	const blobPrefix = 'x'

	return ch == blobPrefix
}
58,367✔
647

648
// isSeparator reports whether ch is the statement separator ';'.
func isSeparator(ch byte) bool {
	const stmtSeparator = ';'

	return ch == stmtSeparator
}
58,840✔
651

652
// isLineBreak reports whether ch is a carriage return or a line feed.
func isLineBreak(ch byte) bool {
	switch ch {
	case '\r', '\n':
		return true
	default:
		return false
	}
}
89,684✔
655

656
// isSpace reports whether ch is a blank or a horizontal tab. Line breaks
// are not considered spaces; see isLineBreak.
func isSpace(ch byte) bool {
	return ch == ' ' || ch == '\t'
}
88,378✔
659

660
// isNumber reports whether ch is an ASCII decimal digit.
func isNumber(ch byte) bool {
	return ch >= '0' && ch <= '9'
}
78,461✔
663

664
// isLetter reports whether ch is an ASCII letter or an underscore.
func isLetter(ch byte) bool {
	switch {
	case ch >= 'a' && ch <= 'z':
		return true
	case ch >= 'A' && ch <= 'Z':
		return true
	case ch == '_':
		return true
	}

	return false
}
242,457✔
667

668
// isComparison reports whether ch is one of the characters comparison
// operators are made of: '!', '<', '=', '>' or '~'.
func isComparison(ch byte) bool {
	switch ch {
	case '!', '<', '=', '>', '~':
		return true
	default:
		return false
	}
}
31,163✔
671

672
// isQuote reports whether ch is a single quote (0x27).
func isQuote(ch byte) bool {
	return ch == '\''
}
56,251✔
675

676
// isDoubleQuote reports whether ch is a double quote (0x22).
func isDoubleQuote(ch byte) bool {
	return ch == '"'
}
31,604✔
679

680
// isDot reports whether ch is a period.
func isDot(ch byte) bool {
	const dot = '.'

	return ch == dot
}
24,187✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc