• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 6458421232

09 Oct 2023 03:04PM UTC coverage: 89.499% (+0.2%) from 89.257%
6458421232

push

gh-ci

jeroiraz
test(embedded/tbtree): nodeRef coverage

Signed-off-by: Jeronimo Irazabal <jeronimo.irazabal@gmail.com>

33451 of 37376 relevant lines covered (89.5%)

144180.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.47
/embedded/sql/parser.go
1
/*
2
Copyright 2022 Codenotary Inc. All rights reserved.
3

4
Licensed under the Apache License, Version 2.0 (the "License");
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
        http://www.apache.org/licenses/LICENSE-2.0
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "encoding/hex"
22
        "errors"
23
        "fmt"
24
        "io"
25
        "strconv"
26
        "strings"
27
)
28

29
//go:generate go run golang.org/x/tools/cmd/goyacc -l -o sql_parser.go sql_grammar.y
30

31
var reservedWords = map[string]int{
32
        "CREATE":         CREATE,
33
        "DROP":           DROP,
34
        "USE":            USE,
35
        "DATABASE":       DATABASE,
36
        "SNAPSHOT":       SNAPSHOT,
37
        "HISTORY":        HISTORY,
38
        "OF":             OF,
39
        "SINCE":          SINCE,
40
        "AFTER":          AFTER,
41
        "BEFORE":         BEFORE,
42
        "UNTIL":          UNTIL,
43
        "TABLE":          TABLE,
44
        "PRIMARY":        PRIMARY,
45
        "KEY":            KEY,
46
        "UNIQUE":         UNIQUE,
47
        "INDEX":          INDEX,
48
        "ON":             ON,
49
        "ALTER":          ALTER,
50
        "ADD":            ADD,
51
        "RENAME":         RENAME,
52
        "TO":             TO,
53
        "COLUMN":         COLUMN,
54
        "INSERT":         INSERT,
55
        "CONFLICT":       CONFLICT,
56
        "DO":             DO,
57
        "NOTHING":        NOTHING,
58
        "UPSERT":         UPSERT,
59
        "INTO":           INTO,
60
        "VALUES":         VALUES,
61
        "UPDATE":         UPDATE,
62
        "SET":            SET,
63
        "DELETE":         DELETE,
64
        "BEGIN":          BEGIN,
65
        "TRANSACTION":    TRANSACTION,
66
        "COMMIT":         COMMIT,
67
        "ROLLBACK":       ROLLBACK,
68
        "SELECT":         SELECT,
69
        "DISTINCT":       DISTINCT,
70
        "FROM":           FROM,
71
        "UNION":          UNION,
72
        "ALL":            ALL,
73
        "TX":             TX,
74
        "JOIN":           JOIN,
75
        "HAVING":         HAVING,
76
        "WHERE":          WHERE,
77
        "GROUP":          GROUP,
78
        "BY":             BY,
79
        "LIMIT":          LIMIT,
80
        "OFFSET":         OFFSET,
81
        "ORDER":          ORDER,
82
        "AS":             AS,
83
        "ASC":            ASC,
84
        "DESC":           DESC,
85
        "NOT":            NOT,
86
        "LIKE":           LIKE,
87
        "EXISTS":         EXISTS,
88
        "IN":             IN,
89
        "AUTO_INCREMENT": AUTO_INCREMENT,
90
        "NULL":           NULL,
91
        "IF":             IF,
92
        "IS":             IS,
93
        "CAST":           CAST,
94
        "::":             SCAST,
95
}
96

97
var joinTypes = map[string]JoinType{
98
        "INNER": InnerJoin,
99
        "LEFT":  LeftJoin,
100
        "RIGHT": RightJoin,
101
}
102

103
var types = map[string]SQLValueType{
104
        "INTEGER":   IntegerType,
105
        "BOOLEAN":   BooleanType,
106
        "VARCHAR":   VarcharType,
107
        "UUID":      UUIDType,
108
        "BLOB":      BLOBType,
109
        "TIMESTAMP": TimestampType,
110
        "FLOAT":     Float64Type,
111
}
112

113
var aggregateFns = map[string]AggregateFn{
114
        "COUNT": COUNT,
115
        "SUM":   SUM,
116
        "MAX":   MAX,
117
        "MIN":   MIN,
118
        "AVG":   AVG,
119
}
120

121
var boolValues = map[string]bool{
122
        "TRUE":  true,
123
        "FALSE": false,
124
}
125

126
var cmpOps = map[string]CmpOperator{
127
        "=":  EQ,
128
        "!=": NE,
129
        "<>": NE,
130
        "<":  LT,
131
        "<=": LE,
132
        ">":  GT,
133
        ">=": GE,
134
}
135

136
var logicOps = map[string]LogicOperator{
137
        "AND": AND,
138
        "OR":  OR,
139
}
140

141
// Errors reported by the lexer for parameter misuse.
var (
	// ErrEitherNamedOrUnnamedParams is reported when unnamed "?" parameters
	// are mixed with named ("@name" or "$n") parameters in the same input.
	ErrEitherNamedOrUnnamedParams = errors.New("either named or unnamed params")

	// ErrEitherPosOrNonPosParams is reported when positional "$n" parameters
	// are mixed with non-positional "@name" parameters in the same input.
	ErrEitherPosOrNonPosParams = errors.New("either positional or non-positional named params")

	// ErrInvalidPositionalParameter is reported for a positional parameter
	// numbered below 1, such as "$0".
	ErrInvalidPositionalParameter = errors.New("invalid positional parameter")
)
144

145
// positionalParamType records which parameter style the lexer has seen so far
// in the current input. The zero value means no parameter has been lexed yet.
type positionalParamType int

// Parameter styles. Numbering starts at 1 so that none of them collides with
// the zero value.
const (
	// NamedNonPositionalParamType: parameters written as "@name".
	NamedNonPositionalParamType positionalParamType = 1

	// NamedPositionalParamType: parameters written as "$n".
	NamedPositionalParamType positionalParamType = 2

	// UnnamedParamType: parameters written as "?".
	UnnamedParamType positionalParamType = 3
)
152

153
type lexer struct {
154
        r               *aheadByteReader
155
        err             error
156
        namedParamsType positionalParamType
157
        paramsCount     int
158
        result          []SQLStmt
159
}
160

161
// aheadByteReader wraps an io.ByteReader with one byte of lookahead: the byte
// (and error) that the next ReadByte call will return is always buffered and
// can be inspected via NextByte or the nextChar field without consuming it.
type aheadByteReader struct {
	nextChar  byte          // byte the next ReadByte call returns
	nextErr   error         // error the next ReadByte call returns
	r         io.ByteReader // underlying source
	readCount int           // number of ReadByte calls, including failed ones
}

// newAheadByteReader wraps r and immediately reads its first byte (or error)
// into the lookahead slot.
func newAheadByteReader(r io.ByteReader) *aheadByteReader {
	first, err := r.ReadByte()
	return &aheadByteReader{
		nextChar: first,
		nextErr:  err,
		r:        r,
	}
}

// ReadByte returns the buffered lookahead byte and error, then refills the
// lookahead from the underlying reader. Once the underlying reader has
// returned an error it is never read again: the same byte and error are
// returned on every later call.
func (ar *aheadByteReader) ReadByte() (byte, error) {
	ar.readCount++

	ch, err := ar.nextChar, ar.nextErr
	if err != nil {
		return ch, err
	}

	ar.nextChar, ar.nextErr = ar.r.ReadByte()
	return ch, nil
}

// ReadCount reports how many times ReadByte has been called, counting calls
// that returned an error.
func (ar *aheadByteReader) ReadCount() int { return ar.readCount }

// NextByte peeks at the byte and error the next ReadByte call will return,
// without consuming anything.
func (ar *aheadByteReader) NextByte() (byte, error) {
	next, err := ar.nextChar, ar.nextErr
	return next, err
}
193

194
func ParseString(sql string) ([]SQLStmt, error) {
98✔
195
        return Parse(strings.NewReader(sql))
98✔
196
}
98✔
197

198
func Parse(r io.ByteReader) ([]SQLStmt, error) {
1,732✔
199
        lexer := newLexer(r)
1,732✔
200

1,732✔
201
        yyParse(lexer)
1,732✔
202

1,732✔
203
        return lexer.result, lexer.err
1,732✔
204
}
1,732✔
205

206
func newLexer(r io.ByteReader) *lexer {
1,732✔
207
        return &lexer{
1,732✔
208
                r:   newAheadByteReader(r),
1,732✔
209
                err: nil,
1,732✔
210
        }
1,732✔
211
}
1,732✔
212

213
func (l *lexer) Lex(lval *yySymType) int {
24,270✔
214
        var ch byte
24,270✔
215
        var err error
24,270✔
216

24,270✔
217
        for {
63,429✔
218
                ch, err = l.r.ReadByte()
39,159✔
219
                if err == io.EOF {
40,854✔
220
                        return 0
1,695✔
221
                }
1,695✔
222
                if err != nil {
37,464✔
223
                        lval.err = err
×
224
                        return ERROR
×
225
                }
×
226

227
                if ch == '\t' {
40,130✔
228
                        continue
2,666✔
229
                }
230

231
                if ch == '/' && l.r.nextChar == '*' {
34,801✔
232
                        l.r.ReadByte()
3✔
233

3✔
234
                        for {
102✔
235
                                ch, err := l.r.ReadByte()
99✔
236
                                if err == io.EOF {
99✔
237
                                        break
×
238
                                }
239
                                if err != nil {
99✔
240
                                        lval.err = err
×
241
                                        return ERROR
×
242
                                }
×
243

244
                                if ch == '*' && l.r.nextChar == '/' {
102✔
245
                                        l.r.ReadByte() // consume closing slash
3✔
246
                                        break
3✔
247
                                }
248
                        }
249

250
                        continue
3✔
251
                }
252

253
                if isLineBreak(ch) {
35,592✔
254
                        if ch == '\r' && l.r.nextChar == '\n' {
798✔
255
                                l.r.ReadByte()
1✔
256
                        }
1✔
257
                        continue
797✔
258
                }
259

260
                if !isSpace(ch) {
56,573✔
261
                        break
22,575✔
262
                }
263
        }
264

265
        if isSeparator(ch) {
22,952✔
266
                return STMT_SEPARATOR
377✔
267
        }
377✔
268

269
        if isBLOBPrefix(ch) && isQuote(l.r.nextChar) {
22,291✔
270
                l.r.ReadByte() // consume starting quote
93✔
271

93✔
272
                tail, err := l.readString()
93✔
273
                if err != nil {
93✔
274
                        lval.err = err
×
275
                        return ERROR
×
276
                }
×
277

278
                val, err := hex.DecodeString(tail)
93✔
279
                if err != nil {
93✔
280
                        lval.err = err
×
281
                        return ERROR
×
282
                }
×
283

284
                lval.blob = val
93✔
285
                return BLOB
93✔
286
        }
287

288
        if isLetter(ch) {
34,586✔
289
                tail, err := l.readWord()
12,481✔
290
                if err != nil {
12,481✔
291
                        lval.err = err
×
292
                        return ERROR
×
293
                }
×
294

295
                w := fmt.Sprintf("%c%s", ch, tail)
12,481✔
296
                tid := strings.ToUpper(w)
12,481✔
297

12,481✔
298
                sqlType, ok := types[tid]
12,481✔
299
                if ok {
13,067✔
300
                        lval.sqlType = sqlType
586✔
301
                        return TYPE
586✔
302
                }
586✔
303

304
                val, ok := boolValues[tid]
11,895✔
305
                if ok {
12,079✔
306
                        lval.boolean = val
184✔
307
                        return BOOLEAN
184✔
308
                }
184✔
309

310
                lop, ok := logicOps[tid]
11,711✔
311
                if ok {
11,827✔
312
                        lval.logicOp = lop
116✔
313
                        return LOP
116✔
314
                }
116✔
315

316
                afn, ok := aggregateFns[tid]
11,595✔
317
                if ok {
11,681✔
318
                        lval.aggFn = afn
86✔
319
                        return AGGREGATE_FUNC
86✔
320
                }
86✔
321

322
                join, ok := joinTypes[tid]
11,509✔
323
                if ok {
11,526✔
324
                        lval.joinType = join
17✔
325
                        return JOINTYPE
17✔
326
                }
17✔
327

328
                tkn, ok := reservedWords[tid]
11,492✔
329
                if ok {
17,536✔
330
                        return tkn
6,044✔
331
                }
6,044✔
332

333
                lval.id = strings.ToLower(w)
5,448✔
334

5,448✔
335
                return IDENTIFIER
5,448✔
336
        }
337

338
        if isDoubleQuote(ch) {
9,631✔
339
                tail, err := l.readWord()
7✔
340
                if err != nil {
7✔
341
                        lval.err = err
×
342
                        return ERROR
×
343
                }
×
344

345
                if !isDoubleQuote(l.r.nextChar) {
8✔
346
                        lval.err = fmt.Errorf("double quote expected")
1✔
347
                        return ERROR
1✔
348
                }
1✔
349

350
                l.r.ReadByte() // consume ending quote
6✔
351

6✔
352
                lval.id = strings.ToLower(tail)
6✔
353
                return IDENTIFIER
6✔
354
        }
355

356
        if isNumber(ch) {
10,331✔
357
                tail, err := l.readNumber()
714✔
358
                if err != nil {
714✔
359
                        lval.err = err
×
360
                        return ERROR
×
361
                }
×
362
                // looking for a float
363
                if isDot(l.r.nextChar) {
738✔
364
                        l.r.ReadByte() // consume dot
24✔
365

24✔
366
                        decimalPart, err := l.readNumber()
24✔
367
                        if err != nil {
24✔
368
                                lval.err = err
×
369
                                return ERROR
×
370
                        }
×
371

372
                        val, err := strconv.ParseFloat(fmt.Sprintf("%c%s.%s", ch, tail, decimalPart), 64)
24✔
373
                        if err != nil {
25✔
374
                                lval.err = err
1✔
375
                                return ERROR
1✔
376
                        }
1✔
377

378
                        lval.float = val
23✔
379
                        return FLOAT
23✔
380
                }
381

382
                val, err := strconv.ParseUint(fmt.Sprintf("%c%s", ch, tail), 10, 64)
690✔
383
                if err != nil {
690✔
384
                        lval.err = err
×
385
                        return ERROR
×
386
                }
×
387

388
                lval.integer = val
690✔
389
                return INTEGER
690✔
390
        }
391

392
        if isComparison(ch) {
9,200✔
393
                tail, err := l.readComparison()
297✔
394
                if err != nil {
297✔
395
                        lval.err = err
×
396
                        return ERROR
×
397
                }
×
398

399
                op := fmt.Sprintf("%c%s", ch, tail)
297✔
400

297✔
401
                cmpOp, ok := cmpOps[op]
297✔
402
                if !ok {
297✔
403
                        lval.err = fmt.Errorf("invalid comparison operator %s", op)
×
404
                        return ERROR
×
405
                }
×
406

407
                lval.cmpOp = cmpOp
297✔
408
                return CMPOP
297✔
409
        }
410

411
        if isQuote(ch) {
8,971✔
412
                tail, err := l.readString()
365✔
413
                if err != nil {
365✔
414
                        lval.err = err
×
415
                        return ERROR
×
416
                }
×
417

418
                lval.str = tail
365✔
419
                return VARCHAR
365✔
420
        }
421

422
        if ch == ':' {
8,243✔
423
                ch, err := l.r.ReadByte()
2✔
424
                if err != nil {
2✔
425
                        lval.err = err
×
426
                        return ERROR
×
427
                }
×
428

429
                if ch != ':' {
2✔
430
                        lval.err = fmt.Errorf("colon expected")
×
431
                        return ERROR
×
432
                }
×
433

434
                return SCAST
2✔
435
        }
436

437
        if ch == '@' {
8,821✔
438
                if l.namedParamsType == UnnamedParamType {
583✔
439
                        lval.err = ErrEitherNamedOrUnnamedParams
1✔
440
                        return ERROR
1✔
441
                }
1✔
442

443
                if l.namedParamsType == NamedPositionalParamType {
582✔
444
                        lval.err = ErrEitherPosOrNonPosParams
1✔
445
                        return ERROR
1✔
446
                }
1✔
447

448
                l.namedParamsType = NamedNonPositionalParamType
580✔
449

580✔
450
                ch, err := l.r.NextByte()
580✔
451
                if err != nil {
580✔
452
                        lval.err = err
×
453
                        return ERROR
×
454
                }
×
455

456
                if !isLetter(ch) {
580✔
457
                        return ERROR
×
458
                }
×
459

460
                id, err := l.readWord()
580✔
461
                if err != nil {
580✔
462
                        lval.err = err
×
463
                        return ERROR
×
464
                }
×
465

466
                lval.id = strings.ToLower(id)
580✔
467

580✔
468
                return NPARAM
580✔
469
        }
470

471
        if ch == '$' {
7,699✔
472
                if l.namedParamsType == UnnamedParamType {
43✔
473
                        lval.err = ErrEitherNamedOrUnnamedParams
1✔
474
                        return ERROR
1✔
475
                }
1✔
476

477
                if l.namedParamsType == NamedNonPositionalParamType {
42✔
478
                        lval.err = ErrEitherPosOrNonPosParams
1✔
479
                        return ERROR
1✔
480
                }
1✔
481

482
                id, err := l.readNumber()
40✔
483
                if err != nil {
40✔
484
                        lval.err = err
×
485
                        return ERROR
×
486
                }
×
487

488
                pid, err := strconv.Atoi(id)
40✔
489
                if err != nil {
41✔
490
                        lval.err = err
1✔
491
                        return ERROR
1✔
492
                }
1✔
493

494
                if pid < 1 {
40✔
495
                        lval.err = ErrInvalidPositionalParameter
1✔
496
                        return ERROR
1✔
497
                }
1✔
498

499
                lval.pparam = pid
38✔
500

38✔
501
                l.namedParamsType = NamedPositionalParamType
38✔
502

38✔
503
                return PPARAM
38✔
504
        }
505

506
        if ch == '?' {
7,752✔
507
                if l.namedParamsType == NamedNonPositionalParamType || l.namedParamsType == NamedPositionalParamType {
139✔
508
                        lval.err = ErrEitherNamedOrUnnamedParams
2✔
509
                        return ERROR
2✔
510
                }
2✔
511

512
                l.paramsCount++
135✔
513
                lval.pparam = l.paramsCount
135✔
514

135✔
515
                l.namedParamsType = UnnamedParamType
135✔
516

135✔
517
                return PPARAM
135✔
518
        }
519

520
        if isDot(ch) {
7,576✔
521
                if isNumber(l.r.nextChar) { // looking for  a float
103✔
522
                        decimalPart, err := l.readNumber()
5✔
523
                        if err != nil {
5✔
524
                                lval.err = err
×
525
                                return ERROR
×
526
                        }
×
527
                        val, err := strconv.ParseFloat(fmt.Sprintf("%d.%s", 0, decimalPart), 64)
5✔
528
                        if err != nil {
5✔
529
                                lval.err = err
×
530
                                return ERROR
×
531
                        }
×
532
                        lval.float = val
5✔
533
                        return FLOAT
5✔
534
                }
535
                return DOT
93✔
536
        }
537

538
        return int(ch)
7,380✔
539
}
540

541
func (l *lexer) Error(err string) {
39✔
542
        l.err = fmt.Errorf("%s at position %d", err, l.r.ReadCount())
39✔
543
}
39✔
544

545
func (l *lexer) readWord() (string, error) {
13,068✔
546
        return l.readWhile(func(ch byte) bool {
78,830✔
547
                return isLetter(ch) || isNumber(ch)
65,762✔
548
        })
65,762✔
549
}
550

551
func (l *lexer) readNumber() (string, error) {
783✔
552
        return l.readWhile(isNumber)
783✔
553
}
783✔
554

555
func (l *lexer) readString() (string, error) {
458✔
556
        var b bytes.Buffer
458✔
557

458✔
558
        for {
3,914✔
559
                ch, err := l.r.ReadByte()
3,456✔
560
                if err != nil {
3,456✔
561
                        return "", err
×
562
                }
×
563

564
                nextCh, _ := l.r.NextByte()
3,456✔
565

3,456✔
566
                if isQuote(ch) {
3,916✔
567
                        if isQuote(nextCh) {
462✔
568
                                l.r.ReadByte() // consume escaped quote
2✔
569
                        } else {
460✔
570
                                break // string completely read
458✔
571
                        }
572
                }
573

574
                b.WriteByte(ch)
2,998✔
575
        }
576

577
        return b.String(), nil
458✔
578
}
579

580
func (l *lexer) readComparison() (string, error) {
297✔
581
        return l.readWhile(func(ch byte) bool {
646✔
582
                return isComparison(ch)
349✔
583
        })
349✔
584
}
585

586
func (l *lexer) readWhile(condFn func(b byte) bool) (string, error) {
14,148✔
587
        var b bytes.Buffer
14,148✔
588

14,148✔
589
        for {
91,785✔
590
                ch, err := l.r.NextByte()
77,637✔
591
                if err == io.EOF {
78,015✔
592
                        break
378✔
593
                }
594
                if err != nil {
77,259✔
595
                        return "", err
×
596
                }
×
597

598
                if !condFn(ch) {
91,029✔
599
                        break
13,770✔
600
                }
601

602
                ch, _ = l.r.ReadByte()
63,489✔
603
                b.WriteByte(ch)
63,489✔
604
        }
605

606
        return b.String(), nil
14,148✔
607
}
608

609
// Byte-classification helpers used by the lexer. Each compares a single byte
// against ASCII characters only, so bytes >= 0x80 (parts of multi-byte UTF-8
// sequences) never match.

// isBLOBPrefix reports whether ch is the lower-case 'x' that may introduce a
// hex BLOB literal (x'...').
func isBLOBPrefix(ch byte) bool {
	return ch == 'x'
}

// isSeparator reports whether ch is the statement separator ';'.
func isSeparator(ch byte) bool {
	return ch == ';'
}

// isLineBreak reports whether ch is a carriage return or a line feed.
func isLineBreak(ch byte) bool {
	switch ch {
	case '\r', '\n':
		return true
	}
	return false
}

// isSpace reports whether ch is a space or a horizontal tab.
func isSpace(ch byte) bool {
	switch ch {
	case ' ', '\t':
		return true
	}
	return false
}

// isNumber reports whether ch is an ASCII decimal digit.
func isNumber(ch byte) bool {
	return ch >= '0' && ch <= '9'
}

// isLetter reports whether ch is an ASCII letter of either case or an
// underscore.
func isLetter(ch byte) bool {
	switch {
	case ch >= 'a' && ch <= 'z', ch >= 'A' && ch <= 'Z':
		return true
	default:
		return ch == '_'
	}
}

// isComparison reports whether ch can be part of a comparison operator.
func isComparison(ch byte) bool {
	switch ch {
	case '!', '<', '=', '>':
		return true
	}
	return false
}

// isQuote reports whether ch is a single quote, the delimiter of string and
// BLOB literals.
func isQuote(ch byte) bool {
	return ch == '\''
}

// isDoubleQuote reports whether ch is a double quote, the delimiter of quoted
// identifiers.
func isDoubleQuote(ch byte) bool {
	return ch == '"'
}

// isDot reports whether ch is a period.
func isDot(ch byte) bool {
	return ch == '.'
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc