• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 9791636125

04 Jul 2024 09:10AM UTC coverage: 89.523% (+0.1%) from 89.416%
9791636125

push

gh-ci

ostafen
Implement CHECK constraints

Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>

554 of 628 new or added lines in 10 files covered. (88.22%)

14 existing lines in 7 files now uncovered.

35374 of 39514 relevant lines covered (89.52%)

160396.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.79
/embedded/sql/stmt.go
1
/*
2
Copyright 2024 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "context"
22
        "encoding/binary"
23
        "encoding/hex"
24
        "errors"
25
        "fmt"
26
        "math"
27
        "regexp"
28
        "strconv"
29
        "strings"
30
        "time"
31

32
        "github.com/codenotary/immudb/embedded/store"
33
        "github.com/google/uuid"
34
)
35

36
// Key prefixes under which the SQL layer stores its catalog and table data
// in the underlying key-value store.
const (
	catalogPrefix       = "CTL."
	catalogTablePrefix  = catalogPrefix + "TABLE."  // (key=CTL.TABLE.{1}{tableID}, value={tableNAME})
	catalogColumnPrefix = catalogPrefix + "COLUMN." // (key=CTL.COLUMN.{1}{tableID}{colID}{colTYPE}, value={(auto_incremental | nullable){maxLen}{colNAME}})
	catalogCheckPrefix  = catalogPrefix + "CHECK."  // (key=CTL.CHECK.{1}{tableID}{checkID}, value={nameLen-1}{name}{expText})
	catalogIndexPrefix  = catalogPrefix + "INDEX."  // (key=CTL.INDEX.{1}{tableID}{indexID}, value={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)})

	RowPrefix    = "R." // (key=R.{1}{tableID}{0}({null}({pkVal}{padding}{pkValLen})?)+, value={count (colID valLen val)+})
	MappedPrefix = "M." // (key=M.{tableID}{indexID}({null}({val}{padding}{valLen})?)*({pkVal}{padding}{pkValLen})+, value={count (colID valLen val)+})
)

const (
	DatabaseID uint32 = 1 // deprecated but left to maintain backwards compatibility
	PKIndexID  uint32 = 0
)

// Bit flags stored in the first byte of a column's catalog value.
// Note that nullableFlag is the bit persistColumn sets for NOT NULL columns.
const (
	nullableFlag      byte = 1 << 0
	autoIncrementFlag byte = 1 << 1
)
56

57
// Names of system-managed columns that user tables may not declare.
const (
	revCol        = "_rev"
	txMetadataCol = "_tx_metadata"
)

var reservedColumns = map[string]struct{}{
	revCol:        {},
	txMetadataCol: {},
}

// isReservedCol reports whether col is one of the reserved column names.
// The comparison is exact (case-sensitive).
func isReservedCol(col string) bool {
	if _, reserved := reservedColumns[col]; reserved {
		return true
	}
	return false
}
71

72
// SQLValueType names a SQL column/value type.
type SQLValueType = string

const (
	IntegerType   SQLValueType = "INTEGER"
	BooleanType   SQLValueType = "BOOLEAN"
	VarcharType   SQLValueType = "VARCHAR"
	UUIDType      SQLValueType = "UUID"
	BLOBType      SQLValueType = "BLOB"
	Float64Type   SQLValueType = "FLOAT"
	TimestampType SQLValueType = "TIMESTAMP"
	AnyType       SQLValueType = "ANY"
	JSONType      SQLValueType = "JSON"
)

// IsNumericType reports whether t is INTEGER or FLOAT.
func IsNumericType(t SQLValueType) bool {
	switch t {
	case IntegerType, Float64Type:
		return true
	default:
		return false
	}
}
89

90
// Permission is the access level granted to a database user.
type Permission = string

const (
	PermissionReadOnly  Permission = "READ"
	PermissionReadWrite Permission = "READWRITE"
	PermissionAdmin     Permission = "ADMIN"
)

// AggregateFn names a SQL aggregate function.
type AggregateFn = string

const (
	COUNT AggregateFn = "COUNT"
	SUM   AggregateFn = "SUM"
	MAX   AggregateFn = "MAX"
	MIN   AggregateFn = "MIN"
	AVG   AggregateFn = "AVG"
)
107

108
// CmpOperator identifies a binary comparison operator.
type CmpOperator = int

const (
	EQ CmpOperator = iota
	NE
	LT
	LE
	GT
	GE
)

// CmpOperatorToString returns the SQL symbol of op ("=", "!=", "<", "<=",
// ">" or ">="). It returns an empty string for any value that is not one of
// the declared comparison operators.
func CmpOperatorToString(op CmpOperator) string {
	symbols := [...]string{
		EQ: "=",
		NE: "!=",
		LT: "<",
		LE: "<=",
		GT: ">",
		GE: ">=",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
136

137
// LogicOperator identifies a boolean connective.
type LogicOperator = int

const (
	AND LogicOperator = iota
	OR
)

// LogicOperatorToString returns "AND" for AND and "OR" for any other value.
func LogicOperatorToString(op LogicOperator) string {
	if op != AND {
		return "OR"
	}
	return "AND"
}
150

151
// NumOperator identifies a binary arithmetic operator.
type NumOperator = int

const (
	ADDOP NumOperator = iota
	SUBSOP
	DIVOP
	MULTOP
)

// NumOperatorString returns the symbol of op ("+", "-", "/" or "*"), or an
// empty string when op is not one of the declared arithmetic operators.
func NumOperatorString(op NumOperator) string {
	symbols := [...]string{
		ADDOP:  "+",
		SUBSOP: "-",
		DIVOP:  "/",
		MULTOP: "*",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
173

174
// JoinType identifies the kind of a JOIN clause.
type JoinType = int

const (
	InnerJoin JoinType = iota
	LeftJoin
	RightJoin
)

// Names of the built-in functions recognised by the SQL engine.
const (
	NowFnCall        string = "NOW"
	UUIDFnCall       string = "RANDOM_UUID"
	DatabasesFnCall  string = "DATABASES"
	TablesFnCall     string = "TABLES"
	TableFnCall      string = "TABLE"
	UsersFnCall      string = "USERS"
	ColumnsFnCall    string = "COLUMNS"
	IndexesFnCall    string = "INDEXES"
	JSONTypeOfFnCall string = "JSON_TYPEOF"
)
193

194
type SQLStmt interface {
195
        execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error)
196
        inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error
197
}
198

199
type BeginTransactionStmt struct {
200
}
201

202
func (stmt *BeginTransactionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
203
        return nil
3✔
204
}
3✔
205

206
func (stmt *BeginTransactionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
60✔
207
        if tx.IsExplicitCloseRequired() {
61✔
208
                return nil, ErrNestedTxNotSupported
1✔
209
        }
1✔
210

211
        err := tx.RequireExplicitClose()
59✔
212
        if err == nil {
117✔
213
                // current tx can be reused as no changes were already made
58✔
214
                return tx, nil
58✔
215
        }
58✔
216

217
        // commit current transaction and start a fresh one
218

219
        err = tx.Commit(ctx)
1✔
220
        if err != nil {
1✔
221
                return nil, err
×
222
        }
×
223

224
        return tx.engine.NewTx(ctx, tx.opts.WithExplicitClose(true))
1✔
225
}
226

227
type CommitStmt struct {
228
}
229

230
func (stmt *CommitStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
231
        return nil
1✔
232
}
1✔
233

234
func (stmt *CommitStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
159✔
235
        if !tx.IsExplicitCloseRequired() {
160✔
236
                return nil, ErrNoOngoingTx
1✔
237
        }
1✔
238

239
        return nil, tx.Commit(ctx)
158✔
240
}
241

242
type RollbackStmt struct {
243
}
244

245
func (stmt *RollbackStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
246
        return nil
1✔
247
}
1✔
248

249
func (stmt *RollbackStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
250
        if !tx.IsExplicitCloseRequired() {
5✔
251
                return nil, ErrNoOngoingTx
1✔
252
        }
1✔
253

254
        return nil, tx.Cancel()
3✔
255
}
256

257
type CreateDatabaseStmt struct {
258
        DB          string
259
        ifNotExists bool
260
}
261

262
func (stmt *CreateDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
263
        return nil
4✔
264
}
4✔
265

266
func (stmt *CreateDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
267
        if tx.IsExplicitCloseRequired() {
16✔
268
                return nil, fmt.Errorf("%w: database creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
269
        }
1✔
270

271
        if tx.engine.multidbHandler == nil {
16✔
272
                return nil, ErrUnspecifiedMultiDBHandler
2✔
273
        }
2✔
274

275
        return nil, tx.engine.multidbHandler.CreateDatabase(ctx, stmt.DB, stmt.ifNotExists)
12✔
276
}
277

278
type UseDatabaseStmt struct {
279
        DB string
280
}
281

282
func (stmt *UseDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
283
        return nil
1✔
284
}
1✔
285

286
func (stmt *UseDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
287
        if tx.IsExplicitCloseRequired() {
10✔
288
                return nil, fmt.Errorf("%w: database selection can NOT be executed within a transaction block", ErrNonTransactionalStmt)
1✔
289
        }
1✔
290

291
        if tx.engine.multidbHandler == nil {
9✔
292
                return nil, ErrUnspecifiedMultiDBHandler
1✔
293
        }
1✔
294

295
        return tx, tx.engine.multidbHandler.UseDatabase(ctx, stmt.DB)
7✔
296
}
297

298
type UseSnapshotStmt struct {
299
        period period
300
}
301

302
func (stmt *UseSnapshotStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
303
        return nil
1✔
304
}
1✔
305

306
func (stmt *UseSnapshotStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
307
        return nil, ErrNoSupported
1✔
308
}
1✔
309

310
type CreateUserStmt struct {
311
        username   string
312
        password   string
313
        permission Permission
314
}
315

316
func (stmt *CreateUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
317
        return nil
1✔
318
}
1✔
319

320
func (stmt *CreateUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
321
        if tx.IsExplicitCloseRequired() {
7✔
322
                return nil, fmt.Errorf("%w: user creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
323
        }
1✔
324

325
        if tx.engine.multidbHandler == nil {
6✔
326
                return nil, ErrUnspecifiedMultiDBHandler
1✔
327
        }
1✔
328

329
        return nil, tx.engine.multidbHandler.CreateUser(ctx, stmt.username, stmt.password, stmt.permission)
4✔
330
}
331

332
type AlterUserStmt struct {
333
        username   string
334
        password   string
335
        permission Permission
336
}
337

338
func (stmt *AlterUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
339
        return nil
1✔
340
}
1✔
341

342
func (stmt *AlterUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
5✔
343
        if tx.IsExplicitCloseRequired() {
6✔
344
                return nil, fmt.Errorf("%w: user modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
345
        }
1✔
346

347
        if tx.engine.multidbHandler == nil {
5✔
348
                return nil, ErrUnspecifiedMultiDBHandler
1✔
349
        }
1✔
350

351
        return nil, tx.engine.multidbHandler.AlterUser(ctx, stmt.username, stmt.password, stmt.permission)
3✔
352
}
353

354
type DropUserStmt struct {
355
        username string
356
}
357

358
func (stmt *DropUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
359
        return nil
1✔
360
}
1✔
361

362
func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
363
        if tx.IsExplicitCloseRequired() {
4✔
364
                return nil, fmt.Errorf("%w: user deletion can not be done within a transaction", ErrNonTransactionalStmt)
1✔
365
        }
1✔
366

367
        if tx.engine.multidbHandler == nil {
3✔
368
                return nil, ErrUnspecifiedMultiDBHandler
1✔
369
        }
1✔
370

371
        return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
1✔
372
}
373

374
type CreateTableStmt struct {
375
        table       string
376
        ifNotExists bool
377
        colsSpec    []*ColSpec
378
        checks      []CheckConstraint
379
        pkColNames  []string
380
}
381

382
func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
38✔
383
        return &CreateTableStmt{table: table, ifNotExists: ifNotExists, colsSpec: colsSpec, pkColNames: pkColNames}
38✔
384
}
38✔
385

386
func (stmt *CreateTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
387
        return nil
4✔
388
}
4✔
389

390
func zeroRow(tableName string, cols []*ColSpec) *Row {
233✔
391
        r := Row{
233✔
392
                ValuesByPosition: make([]TypedValue, len(cols)),
233✔
393
                ValuesBySelector: make(map[string]TypedValue, len(cols)),
233✔
394
        }
233✔
395

233✔
396
        for i, col := range cols {
981✔
397
                v := zeroForType(col.colType)
748✔
398

748✔
399
                r.ValuesByPosition[i] = v
748✔
400
                r.ValuesBySelector[EncodeSelector("", tableName, col.colName)] = v
748✔
401
        }
748✔
402
        return &r
233✔
403
}
404

405
func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
221✔
406
        if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
222✔
407
                return tx, nil
1✔
408
        }
1✔
409

410
        colSpecs := make(map[uint32]*ColSpec, len(stmt.colsSpec))
220✔
411
        for i, cs := range stmt.colsSpec {
908✔
412
                colSpecs[uint32(i)+1] = cs
688✔
413
        }
688✔
414

415
        row := zeroRow(stmt.table, stmt.colsSpec)
220✔
416
        for _, check := range stmt.checks {
229✔
417
                value, err := check.exp.reduce(tx, row, stmt.table)
9✔
418
                if err != nil {
11✔
419
                        return nil, err
2✔
420
                }
2✔
421

422
                if value.Type() != BooleanType {
7✔
NEW
423
                        return nil, ErrInvalidCheckConstraint
×
NEW
424
                }
×
425
        }
426

427
        nextUnnamedCheck := 0
218✔
428
        checks := make(map[string]CheckConstraint)
218✔
429
        for id, check := range stmt.checks {
225✔
430
                name := fmt.Sprintf("%s_check%d", stmt.table, nextUnnamedCheck+1)
7✔
431
                if check.name != "" {
9✔
432
                        name = check.name
2✔
433
                } else {
7✔
434
                        nextUnnamedCheck++
5✔
435
                }
5✔
436
                check.id = uint32(id)
7✔
437
                check.name = name
7✔
438
                checks[name] = check
7✔
439
        }
440

441
        table, err := tx.catalog.newTable(stmt.table, colSpecs, checks, uint32(len(colSpecs)))
218✔
442
        if err != nil {
224✔
443
                return nil, err
6✔
444
        }
6✔
445

446
        createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.pkColNames}
212✔
447
        _, err = createIndexStmt.execAt(ctx, tx, params)
212✔
448
        if err != nil {
217✔
449
                return nil, err
5✔
450
        }
5✔
451

452
        for _, col := range table.cols {
876✔
453
                if col.autoIncrement {
743✔
454
                        if len(table.primaryIndex.cols) > 1 || col.id != table.primaryIndex.cols[0].id {
75✔
455
                                return nil, ErrLimitedAutoIncrement
1✔
456
                        }
1✔
457
                }
458

459
                err := persistColumn(tx, col)
668✔
460
                if err != nil {
668✔
461
                        return nil, err
×
462
                }
×
463
        }
464

465
        for _, check := range checks {
213✔
466
                if err := persistCheck(tx, table, &check); err != nil {
7✔
NEW
467
                        return nil, err
×
NEW
468
                }
×
469
        }
470

471
        mappedKey := MapKey(tx.sqlPrefix(), catalogTablePrefix, EncodeID(DatabaseID), EncodeID(table.id))
206✔
472

206✔
473
        err = tx.set(mappedKey, nil, []byte(table.name))
206✔
474
        if err != nil {
206✔
475
                return nil, err
×
476
        }
×
477

478
        tx.mutatedCatalog = true
206✔
479

206✔
480
        return tx, nil
206✔
481
}
482

483
func persistColumn(tx *SQLTx, col *Column) error {
688✔
484
        //{auto_incremental | nullable}{maxLen}{colNAME})
688✔
485
        v := make([]byte, 1+4+len(col.colName))
688✔
486

688✔
487
        if col.autoIncrement {
761✔
488
                v[0] = v[0] | autoIncrementFlag
73✔
489
        }
73✔
490

491
        if col.notNull {
733✔
492
                v[0] = v[0] | nullableFlag
45✔
493
        }
45✔
494

495
        binary.BigEndian.PutUint32(v[1:], uint32(col.MaxLen()))
688✔
496

688✔
497
        copy(v[5:], []byte(col.Name()))
688✔
498

688✔
499
        mappedKey := MapKey(
688✔
500
                tx.sqlPrefix(),
688✔
501
                catalogColumnPrefix,
688✔
502
                EncodeID(DatabaseID),
688✔
503
                EncodeID(col.table.id),
688✔
504
                EncodeID(col.id),
688✔
505
                []byte(col.colType),
688✔
506
        )
688✔
507

688✔
508
        return tx.set(mappedKey, nil, v)
688✔
509
}
510

511
func persistCheck(tx *SQLTx, table *Table, check *CheckConstraint) error {
7✔
512
        mappedKey := MapKey(
7✔
513
                tx.sqlPrefix(),
7✔
514
                catalogCheckPrefix,
7✔
515
                EncodeID(DatabaseID),
7✔
516
                EncodeID(table.id),
7✔
517
                EncodeID(check.id),
7✔
518
        )
7✔
519

7✔
520
        name := check.name
7✔
521
        expText := check.exp.String()
7✔
522

7✔
523
        val := make([]byte, 2+len(name)+len(expText))
7✔
524

7✔
525
        if len(name) > 256 {
7✔
NEW
526
                return fmt.Errorf("constraint name len: %w", ErrMaxLengthExceeded)
×
NEW
527
        }
×
528

529
        val[0] = byte(len(name)) - 1
7✔
530

7✔
531
        copy(val[1:], []byte(name))
7✔
532
        copy(val[1+len(name):], []byte(expText))
7✔
533

7✔
534
        return tx.set(mappedKey, nil, val)
7✔
535
}
536

537
type ColSpec struct {
538
        colName       string
539
        colType       SQLValueType
540
        maxLen        int
541
        autoIncrement bool
542
        notNull       bool
543
}
544

545
func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {
188✔
546
        return &ColSpec{
188✔
547
                colName:       name,
188✔
548
                colType:       colType,
188✔
549
                maxLen:        maxLen,
188✔
550
                autoIncrement: autoIncrement,
188✔
551
                notNull:       notNull,
188✔
552
        }
188✔
553
}
188✔
554

555
type CreateIndexStmt struct {
556
        unique      bool
557
        ifNotExists bool
558
        table       string
559
        cols        []string
560
}
561

562
func NewCreateIndexStmt(table string, cols []string, isUnique bool) *CreateIndexStmt {
72✔
563
        return &CreateIndexStmt{unique: isUnique, table: table, cols: cols}
72✔
564
}
72✔
565

566
func (stmt *CreateIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
567
        return nil
1✔
568
}
1✔
569

570
func (stmt *CreateIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
363✔
571
        if len(stmt.cols) < 1 {
364✔
572
                return nil, ErrIllegalArguments
1✔
573
        }
1✔
574

575
        if len(stmt.cols) > MaxNumberOfColumnsInIndex {
363✔
576
                return nil, ErrMaxNumberOfColumnsInIndexExceeded
1✔
577
        }
1✔
578

579
        table, err := tx.catalog.GetTableByName(stmt.table)
361✔
580
        if err != nil {
363✔
581
                return nil, err
2✔
582
        }
2✔
583

584
        colIDs := make([]uint32, len(stmt.cols))
359✔
585

359✔
586
        indexKeyLen := 0
359✔
587

359✔
588
        for i, colName := range stmt.cols {
745✔
589
                col, err := table.GetColumnByName(colName)
386✔
590
                if err != nil {
391✔
591
                        return nil, err
5✔
592
                }
5✔
593

594
                if col.Type() == JSONType {
383✔
595
                        return nil, ErrCannotIndexJson
2✔
596
                }
2✔
597

598
                if variableSizedType(col.colType) && !tx.engine.lazyIndexConstraintValidation && (col.MaxLen() == 0 || col.MaxLen() > MaxKeyLen) {
381✔
599
                        return nil, fmt.Errorf("%w: can not create index using column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
2✔
600
                }
2✔
601

602
                indexKeyLen += col.MaxLen()
377✔
603

377✔
604
                colIDs[i] = col.id
377✔
605
        }
606

607
        if !tx.engine.lazyIndexConstraintValidation && indexKeyLen > MaxKeyLen {
350✔
608
                return nil, fmt.Errorf("%w: can not create index using columns '%v'. Max key length is %d", ErrLimitedKeyType, stmt.cols, MaxKeyLen)
×
609
        }
×
610

611
        if stmt.unique && table.primaryIndex != nil {
370✔
612
                // check table is empty
20✔
613
                pkPrefix := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id))
20✔
614
                _, _, err := tx.getWithPrefix(ctx, pkPrefix, nil)
20✔
615
                if errors.Is(err, store.ErrIndexNotFound) {
20✔
616
                        return nil, ErrTableDoesNotExist
×
617
                }
×
618
                if err == nil {
21✔
619
                        return nil, ErrLimitedIndexCreation
1✔
620
                } else if !errors.Is(err, store.ErrKeyNotFound) {
20✔
621
                        return nil, err
×
622
                }
×
623
        }
624

625
        index, err := table.newIndex(stmt.unique, colIDs)
349✔
626
        if errors.Is(err, ErrIndexAlreadyExists) && stmt.ifNotExists {
351✔
627
                return tx, nil
2✔
628
        }
2✔
629
        if err != nil {
351✔
630
                return nil, err
4✔
631
        }
4✔
632

633
        // v={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)}
634
        // TODO: currently only ASC order is supported
635
        colSpecLen := EncIDLen + 1
343✔
636

343✔
637
        encodedValues := make([]byte, 1+len(index.cols)*colSpecLen)
343✔
638

343✔
639
        if index.IsUnique() {
568✔
640
                encodedValues[0] = 1
225✔
641
        }
225✔
642

643
        for i, col := range index.cols {
713✔
644
                copy(encodedValues[1+i*colSpecLen:], EncodeID(col.id))
370✔
645
        }
370✔
646

647
        mappedKey := MapKey(tx.sqlPrefix(), catalogIndexPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(index.id))
343✔
648

343✔
649
        err = tx.set(mappedKey, nil, encodedValues)
343✔
650
        if err != nil {
343✔
651
                return nil, err
×
652
        }
×
653

654
        tx.mutatedCatalog = true
343✔
655

343✔
656
        return tx, nil
343✔
657
}
658

659
type AddColumnStmt struct {
660
        table   string
661
        colSpec *ColSpec
662
}
663

664
func NewAddColumnStmt(table string, colSpec *ColSpec) *AddColumnStmt {
6✔
665
        return &AddColumnStmt{table: table, colSpec: colSpec}
6✔
666
}
6✔
667

668
func (stmt *AddColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
669
        return nil
1✔
670
}
1✔
671

672
func (stmt *AddColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
673
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
674
        if err != nil {
20✔
675
                return nil, err
1✔
676
        }
1✔
677

678
        col, err := table.newColumn(stmt.colSpec)
18✔
679
        if err != nil {
24✔
680
                return nil, err
6✔
681
        }
6✔
682

683
        err = persistColumn(tx, col)
12✔
684
        if err != nil {
12✔
685
                return nil, err
×
686
        }
×
687

688
        tx.mutatedCatalog = true
12✔
689

12✔
690
        return tx, nil
12✔
691
}
692

693
type RenameTableStmt struct {
694
        oldName string
695
        newName string
696
}
697

698
func (stmt *RenameTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
699
        return nil
1✔
700
}
1✔
701

702
func (stmt *RenameTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
703
        table, err := tx.catalog.renameTable(stmt.oldName, stmt.newName)
6✔
704
        if err != nil {
10✔
705
                return nil, err
4✔
706
        }
4✔
707

708
        // update table name
709
        mappedKey := MapKey(
2✔
710
                tx.sqlPrefix(),
2✔
711
                catalogTablePrefix,
2✔
712
                EncodeID(DatabaseID),
2✔
713
                EncodeID(table.id),
2✔
714
        )
2✔
715
        err = tx.set(mappedKey, nil, []byte(stmt.newName))
2✔
716
        if err != nil {
2✔
717
                return nil, err
×
718
        }
×
719

720
        tx.mutatedCatalog = true
2✔
721

2✔
722
        return tx, nil
2✔
723
}
724

725
type RenameColumnStmt struct {
726
        table   string
727
        oldName string
728
        newName string
729
}
730

731
func NewRenameColumnStmt(table, oldName, newName string) *RenameColumnStmt {
3✔
732
        return &RenameColumnStmt{table: table, oldName: oldName, newName: newName}
3✔
733
}
3✔
734

735
func (stmt *RenameColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
736
        return nil
1✔
737
}
1✔
738

739
func (stmt *RenameColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
10✔
740
        table, err := tx.catalog.GetTableByName(stmt.table)
10✔
741
        if err != nil {
11✔
742
                return nil, err
1✔
743
        }
1✔
744

745
        col, err := table.renameColumn(stmt.oldName, stmt.newName)
9✔
746
        if err != nil {
12✔
747
                return nil, err
3✔
748
        }
3✔
749

750
        err = persistColumn(tx, col)
6✔
751
        if err != nil {
6✔
752
                return nil, err
×
753
        }
×
754

755
        tx.mutatedCatalog = true
6✔
756

6✔
757
        return tx, nil
6✔
758
}
759

760
type DropColumnStmt struct {
761
        table   string
762
        colName string
763
}
764

765
func NewDropColumnStmt(table, colName string) *DropColumnStmt {
8✔
766
        return &DropColumnStmt{table: table, colName: colName}
8✔
767
}
8✔
768

769
func (stmt *DropColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
770
        return nil
1✔
771
}
1✔
772

773
func (stmt *DropColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
774
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
775
        if err != nil {
21✔
776
                return nil, err
2✔
777
        }
2✔
778

779
        col, err := table.GetColumnByName(stmt.colName)
17✔
780
        if err != nil {
21✔
781
                return nil, err
4✔
782
        }
4✔
783

784
        err = canDropColumn(tx, table, col)
13✔
785
        if err != nil {
14✔
786
                return nil, err
1✔
787
        }
1✔
788

789
        err = table.deleteColumn(col)
12✔
790
        if err != nil {
16✔
791
                return nil, err
4✔
792
        }
4✔
793

794
        err = persistColumnDeletion(ctx, tx, col)
8✔
795
        if err != nil {
8✔
796
                return nil, err
×
797
        }
×
798

799
        tx.mutatedCatalog = true
8✔
800

8✔
801
        return tx, nil
8✔
802
}
803

804
func canDropColumn(tx *SQLTx, table *Table, col *Column) error {
13✔
805
        colSpecs := make([]*ColSpec, 0, len(table.Cols())-1)
13✔
806
        for _, c := range table.cols {
86✔
807
                if c.id != col.id {
133✔
808
                        colSpecs = append(colSpecs, &ColSpec{colName: c.Name(), colType: c.Type()})
60✔
809
                }
60✔
810
        }
811

812
        row := zeroRow(table.Name(), colSpecs)
13✔
813
        for name, check := range table.checkConstraints {
20✔
814
                _, err := check.exp.reduce(tx, row, table.name)
7✔
815
                if errors.Is(err, ErrColumnDoesNotExist) {
8✔
816
                        return fmt.Errorf("%w %s because %s constraint requires it", ErrCannotDropColumn, col.Name(), name)
1✔
817
                }
1✔
818

819
                if err != nil {
6✔
NEW
820
                        return err
×
NEW
821
                }
×
822
        }
823
        return nil
12✔
824
}
825

826
func persistColumnDeletion(ctx context.Context, tx *SQLTx, col *Column) error {
9✔
827
        mappedKey := MapKey(
9✔
828
                tx.sqlPrefix(),
9✔
829
                catalogColumnPrefix,
9✔
830
                EncodeID(DatabaseID),
9✔
831
                EncodeID(col.table.id),
9✔
832
                EncodeID(col.id),
9✔
833
                []byte(col.colType),
9✔
834
        )
9✔
835

9✔
836
        return tx.delete(ctx, mappedKey)
9✔
837
}
9✔
838

839
type DropConstraintStmt struct {
840
        table          string
841
        constraintName string
842
}
843

844
func (stmt *DropConstraintStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
845
        table, err := tx.catalog.GetTableByName(stmt.table)
4✔
846
        if err != nil {
4✔
NEW
847
                return nil, err
×
NEW
848
        }
×
849

850
        id, err := table.deleteCheck(stmt.constraintName)
4✔
851
        if err != nil {
5✔
852
                return nil, err
1✔
853
        }
1✔
854

855
        err = persistCheckDeletion(ctx, tx, table.id, id)
3✔
856

3✔
857
        tx.mutatedCatalog = true
3✔
858

3✔
859
        return tx, err
3✔
860
}
861

862
func persistCheckDeletion(ctx context.Context, tx *SQLTx, tableID uint32, checkId uint32) error {
3✔
863
        mappedKey := MapKey(
3✔
864
                tx.sqlPrefix(),
3✔
865
                catalogCheckPrefix,
3✔
866
                EncodeID(DatabaseID),
3✔
867
                EncodeID(tableID),
3✔
868
                EncodeID(checkId),
3✔
869
        )
3✔
870
        return tx.delete(ctx, mappedKey)
3✔
871
}
3✔
872

NEW
873
func (stmt *DropConstraintStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
NEW
874
        return nil
×
NEW
875
}
×
876

877
type UpsertIntoStmt struct {
878
        isInsert   bool
879
        tableRef   *tableRef
880
        cols       []string
881
        rows       []*RowSpec
882
        onConflict *OnConflictDo
883
}
884

885
func NewUpserIntoStmt(table string, cols []string, rows []*RowSpec, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt {
120✔
886
        return &UpsertIntoStmt{
120✔
887
                isInsert:   isInsert,
120✔
888
                tableRef:   NewTableRef(table, ""),
120✔
889
                cols:       cols,
120✔
890
                rows:       rows,
120✔
891
                onConflict: onConflict,
120✔
892
        }
120✔
893
}
120✔
894

895
// RowSpec holds the value expressions of one VALUES tuple, positionally
// matching the statement's column list.
type RowSpec struct {
	Values []ValueExp
}
898

899
func NewRowSpec(values []ValueExp) *RowSpec {
129✔
900
        return &RowSpec{
129✔
901
                Values: values,
129✔
902
        }
129✔
903
}
129✔
904

905
// OnConflictDo marks an INSERT ... ON CONFLICT clause. It carries no data: the
// only resolution currently supported is DO NOTHING.
type OnConflictDo struct{}
906

907
func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
11✔
908
        emptyDescriptors := make(map[string]ColDescriptor)
11✔
909

11✔
910
        for _, row := range stmt.rows {
23✔
911
                if len(stmt.cols) != len(row.Values) {
13✔
912
                        return ErrInvalidNumberOfValues
1✔
913
                }
1✔
914

915
                for i, val := range row.Values {
36✔
916
                        table, err := stmt.tableRef.referencedTable(tx)
25✔
917
                        if err != nil {
26✔
918
                                return err
1✔
919
                        }
1✔
920

921
                        col, err := table.GetColumnByName(stmt.cols[i])
24✔
922
                        if err != nil {
25✔
923
                                return err
1✔
924
                        }
1✔
925

926
                        err = val.requiresType(col.colType, emptyDescriptors, params, table.name)
23✔
927
                        if err != nil {
25✔
928
                                return err
2✔
929
                        }
2✔
930
                }
931
        }
932
        return nil
6✔
933
}
934

935
func (stmt *UpsertIntoStmt) validate(table *Table) (map[uint32]int, error) {
1,893✔
936
        selPosByColID := make(map[uint32]int, len(stmt.cols))
1,893✔
937

1,893✔
938
        for i, c := range stmt.cols {
8,070✔
939
                col, err := table.GetColumnByName(c)
6,177✔
940
                if err != nil {
6,179✔
941
                        return nil, err
2✔
942
                }
2✔
943

944
                _, duplicated := selPosByColID[col.id]
6,175✔
945
                if duplicated {
6,176✔
946
                        return nil, fmt.Errorf("%w (%s)", ErrDuplicatedColumn, col.colName)
1✔
947
                }
1✔
948

949
                selPosByColID[col.id] = i
6,174✔
950
        }
951

952
        return selPosByColID, nil
1,890✔
953
}
954

955
func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1,896✔
956
        table, err := stmt.tableRef.referencedTable(tx)
1,896✔
957
        if err != nil {
1,899✔
958
                return nil, err
3✔
959
        }
3✔
960

961
        selPosByColID, err := stmt.validate(table)
1,893✔
962
        if err != nil {
1,896✔
963
                return nil, err
3✔
964
        }
3✔
965

966
        r := &Row{
1,890✔
967
                ValuesByPosition: make([]TypedValue, len(table.cols)),
1,890✔
968
                ValuesBySelector: make(map[string]TypedValue),
1,890✔
969
        }
1,890✔
970

1,890✔
971
        for _, row := range stmt.rows {
3,868✔
972
                if len(row.Values) != len(stmt.cols) {
1,980✔
973
                        return nil, ErrInvalidNumberOfValues
2✔
974
                }
2✔
975

976
                valuesByColID := make(map[uint32]TypedValue)
1,976✔
977

1,976✔
978
                var pkMustExist bool
1,976✔
979

1,976✔
980
                for colID, col := range table.colsByID {
9,170✔
981
                        colPos, specified := selPosByColID[colID]
7,194✔
982
                        if !specified {
8,042✔
983
                                // TODO: Default values
848✔
984
                                if col.notNull && !col.autoIncrement {
849✔
985
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
1✔
986
                                }
1✔
987

988
                                // inject auto-incremental pk value
989
                                if stmt.isInsert && col.autoIncrement {
1,610✔
990
                                        // current implementation assumes only PK can be set as autoincremental
763✔
991
                                        table.maxPK++
763✔
992

763✔
993
                                        pkCol := table.primaryIndex.cols[0]
763✔
994
                                        valuesByColID[pkCol.id] = &Integer{val: table.maxPK}
763✔
995

763✔
996
                                        if _, ok := tx.firstInsertedPKs[table.name]; !ok {
1,433✔
997
                                                tx.firstInsertedPKs[table.name] = table.maxPK
670✔
998
                                        }
670✔
999
                                        tx.lastInsertedPKs[table.name] = table.maxPK
763✔
1000
                                }
1001

1002
                                continue
847✔
1003
                        }
1004

1005
                        // value was specified
1006
                        cVal := row.Values[colPos]
6,346✔
1007

6,346✔
1008
                        val, err := cVal.substitute(params)
6,346✔
1009
                        if err != nil {
6,349✔
1010
                                return nil, err
3✔
1011
                        }
3✔
1012

1013
                        rval, err := val.reduce(tx, nil, table.name)
6,343✔
1014
                        if err != nil {
6,350✔
1015
                                return nil, err
7✔
1016
                        }
7✔
1017

1018
                        if rval.IsNull() {
6,418✔
1019
                                if col.notNull || col.autoIncrement {
82✔
1020
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
×
1021
                                }
×
1022

1023
                                continue
82✔
1024
                        }
1025

1026
                        if col.autoIncrement {
6,273✔
1027
                                // validate specified value
19✔
1028
                                nl, isNumber := rval.RawValue().(int64)
19✔
1029
                                if !isNumber {
19✔
1030
                                        return nil, fmt.Errorf("%w (expecting numeric value)", ErrInvalidValue)
×
1031
                                }
×
1032

1033
                                pkMustExist = nl <= table.maxPK
19✔
1034

19✔
1035
                                if _, ok := tx.firstInsertedPKs[table.name]; !ok {
38✔
1036
                                        tx.firstInsertedPKs[table.name] = nl
19✔
1037
                                }
19✔
1038
                                tx.lastInsertedPKs[table.name] = nl
19✔
1039
                        }
1040

1041
                        valuesByColID[colID] = rval
6,254✔
1042
                }
1043

1044
                for i, col := range table.cols {
9,131✔
1045
                        v := valuesByColID[col.id]
7,166✔
1046

7,166✔
1047
                        if v == nil {
7,328✔
1048
                                v = NewNull(AnyType)
162✔
1049
                        } else if len(table.checkConstraints) > 0 && col.Type() == JSONType {
7,171✔
1050
                                s, _ := v.RawValue().(string)
5✔
1051
                                jsonVal, err := NewJsonFromString(s)
5✔
1052
                                if err != nil {
5✔
NEW
1053
                                        return nil, err
×
NEW
1054
                                }
×
1055
                                v = jsonVal
5✔
1056
                        }
1057

1058
                        r.ValuesByPosition[i] = v
7,166✔
1059
                        r.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
7,166✔
1060
                }
1061

1062
                if err := checkConstraints(tx, table.checkConstraints, r, table.name); err != nil {
1,971✔
1063
                        return nil, err
6✔
1064
                }
6✔
1065

1066
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
1,959✔
1067
                if err != nil {
1,964✔
1068
                        return nil, err
5✔
1069
                }
5✔
1070

1071
                // pk entry
1072
                mappedPKey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
1,954✔
1073
                if len(mappedPKey) > MaxKeyLen {
1,954✔
1074
                        return nil, ErrMaxKeyLengthExceeded
×
1075
                }
×
1076

1077
                _, err = tx.get(ctx, mappedPKey)
1,954✔
1078
                if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
1,954✔
1079
                        return nil, err
×
1080
                }
×
1081

1082
                if errors.Is(err, store.ErrKeyNotFound) && pkMustExist {
1,956✔
1083
                        return nil, fmt.Errorf("%w: specified value must be greater than current one", ErrInvalidValue)
2✔
1084
                }
2✔
1085

1086
                if stmt.isInsert {
3,723✔
1087
                        if err == nil && stmt.onConflict == nil {
1,775✔
1088
                                return nil, store.ErrKeyAlreadyExists
4✔
1089
                        }
4✔
1090

1091
                        if err == nil && stmt.onConflict != nil {
1,770✔
1092
                                // TODO: conflict resolution may be extended. Currently only supports "ON CONFLICT DO NOTHING"
3✔
1093
                                continue
3✔
1094
                        }
1095
                }
1096

1097
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, !stmt.isInsert)
1,945✔
1098
                if err != nil {
1,958✔
1099
                        return nil, err
13✔
1100
                }
13✔
1101
        }
1102

1103
        return tx, nil
1,847✔
1104
}
1105

1106
func checkConstraints(tx *SQLTx, checks map[string]CheckConstraint, row *Row, table string) error {
1,999✔
1107
        for _, check := range checks {
2,044✔
1108
                val, err := check.exp.reduce(tx, row, table)
45✔
1109
                if err != nil {
46✔
1110
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, err)
1✔
1111
                }
1✔
1112

1113
                if val.Type() != BooleanType {
44✔
NEW
1114
                        return ErrInvalidCheckConstraint
×
NEW
1115
                }
×
1116

1117
                if !val.RawValue().(bool) {
51✔
1118
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, check.exp.String())
7✔
1119
                }
7✔
1120
        }
1121
        return nil
1,991✔
1122
}
1123

1124
func (tx *SQLTx) encodeRowValue(valuesByColID map[uint32]TypedValue, table *Table) ([]byte, error) {
2,134✔
1125
        valbuf := bytes.Buffer{}
2,134✔
1126

2,134✔
1127
        // null values are not serialized
2,134✔
1128
        encodedVals := 0
2,134✔
1129
        for _, v := range valuesByColID {
9,559✔
1130
                if !v.IsNull() {
14,832✔
1131
                        encodedVals++
7,407✔
1132
                }
7,407✔
1133
        }
1134

1135
        b := make([]byte, EncLenLen)
2,134✔
1136
        binary.BigEndian.PutUint32(b, uint32(encodedVals))
2,134✔
1137

2,134✔
1138
        _, err := valbuf.Write(b)
2,134✔
1139
        if err != nil {
2,134✔
1140
                return nil, err
×
1141
        }
×
1142

1143
        for _, col := range table.cols {
9,697✔
1144
                rval, specified := valuesByColID[col.id]
7,563✔
1145
                if !specified || rval.IsNull() {
7,725✔
1146
                        continue
162✔
1147
                }
1148

1149
                b := make([]byte, EncIDLen)
7,401✔
1150
                binary.BigEndian.PutUint32(b, uint32(col.id))
7,401✔
1151

7,401✔
1152
                _, err = valbuf.Write(b)
7,401✔
1153
                if err != nil {
7,401✔
1154
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1155
                }
×
1156

1157
                encVal, err := EncodeValue(rval, col.colType, col.MaxLen())
7,401✔
1158
                if err != nil {
7,409✔
1159
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
8✔
1160
                }
8✔
1161

1162
                _, err = valbuf.Write(encVal)
7,393✔
1163
                if err != nil {
7,393✔
1164
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1165
                }
×
1166
        }
1167

1168
        return valbuf.Bytes(), nil
2,126✔
1169
}
1170

1171
func (tx *SQLTx) doUpsert(ctx context.Context, pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table, reuseIndex bool) error {
1,977✔
1172
        var reusableIndexEntries map[uint32]struct{}
1,977✔
1173

1,977✔
1174
        if reuseIndex && len(table.indexes) > 1 {
2,034✔
1175
                currPKRow, err := tx.fetchPKRow(ctx, table, valuesByColID)
57✔
1176
                if err == nil {
93✔
1177
                        currValuesByColID := make(map[uint32]TypedValue, len(currPKRow.ValuesBySelector))
36✔
1178

36✔
1179
                        for _, col := range table.cols {
161✔
1180
                                encSel := EncodeSelector("", table.name, col.colName)
125✔
1181
                                currValuesByColID[col.id] = currPKRow.ValuesBySelector[encSel]
125✔
1182
                        }
125✔
1183

1184
                        reusableIndexEntries, err = tx.deprecateIndexEntries(pkEncVals, currValuesByColID, valuesByColID, table)
36✔
1185
                        if err != nil {
36✔
1186
                                return err
×
1187
                        }
×
1188
                } else if !errors.Is(err, ErrNoMoreRows) {
21✔
1189
                        return err
×
1190
                }
×
1191
        }
1192

1193
        rowKey := MapKey(tx.sqlPrefix(), RowPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(PKIndexID), pkEncVals)
1,977✔
1194

1,977✔
1195
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
1,977✔
1196
        if err != nil {
1,985✔
1197
                return err
8✔
1198
        }
8✔
1199

1200
        err = tx.set(rowKey, nil, encodedRowValue)
1,969✔
1201
        if err != nil {
1,969✔
1202
                return err
×
1203
        }
×
1204

1205
        // create in-memory and validate entries for secondary indexes
1206
        for _, index := range table.indexes {
4,821✔
1207
                if index.IsPrimary() {
4,821✔
1208
                        continue
1,969✔
1209
                }
1210

1211
                if reusableIndexEntries != nil {
960✔
1212
                        _, reusable := reusableIndexEntries[index.id]
77✔
1213
                        if reusable {
127✔
1214
                                continue
50✔
1215
                        }
1216
                }
1217

1218
                encodedValues := make([][]byte, 2+len(index.cols))
833✔
1219
                encodedValues[0] = EncodeID(table.id)
833✔
1220
                encodedValues[1] = EncodeID(index.id)
833✔
1221

833✔
1222
                indexKeyLen := 0
833✔
1223

833✔
1224
                for i, col := range index.cols {
1,729✔
1225
                        rval, specified := valuesByColID[col.id]
896✔
1226
                        if !specified {
945✔
1227
                                rval = &NullValue{t: col.colType}
49✔
1228
                        }
49✔
1229

1230
                        encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
896✔
1231
                        if err != nil {
896✔
1232
                                return fmt.Errorf("%w: index on '%s' and column '%s'", err, index.Name(), col.colName)
×
1233
                        }
×
1234

1235
                        if n > MaxKeyLen {
896✔
1236
                                return fmt.Errorf("%w: can not index entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1237
                        }
×
1238

1239
                        indexKeyLen += n
896✔
1240

896✔
1241
                        encodedValues[i+2] = encVal
896✔
1242
                }
1243

1244
                if indexKeyLen > MaxKeyLen {
833✔
1245
                        return fmt.Errorf("%w: can not index entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1246
                }
×
1247

1248
                smkey := MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...)
833✔
1249

833✔
1250
                // no other equivalent entry should be already indexed
833✔
1251
                if index.IsUnique() {
910✔
1252
                        _, valRef, err := tx.getWithPrefix(ctx, smkey, nil)
77✔
1253
                        if err == nil && (valRef.KVMetadata() == nil || !valRef.KVMetadata().Deleted()) {
82✔
1254
                                return store.ErrKeyAlreadyExists
5✔
1255
                        } else if !errors.Is(err, store.ErrKeyNotFound) {
77✔
1256
                                return err
×
1257
                        }
×
1258
                }
1259

1260
                err = tx.setTransient(smkey, nil, encodedRowValue) // only-indexable
828✔
1261
                if err != nil {
828✔
1262
                        return err
×
1263
                }
×
1264
        }
1265

1266
        tx.updatedRows++
1,964✔
1267

1,964✔
1268
        return nil
1,964✔
1269
}
1270

1271
func encodedKey(index *Index, valuesByColID map[uint32]TypedValue) ([]byte, error) {
12,886✔
1272
        valbuf := bytes.Buffer{}
12,886✔
1273

12,886✔
1274
        indexKeyLen := 0
12,886✔
1275

12,886✔
1276
        for _, col := range index.cols {
25,784✔
1277
                rval, specified := valuesByColID[col.id]
12,898✔
1278
                if !specified || rval.IsNull() {
12,901✔
1279
                        return nil, ErrPKCanNotBeNull
3✔
1280
                }
3✔
1281

1282
                encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
12,895✔
1283
                if err != nil {
12,897✔
1284
                        return nil, fmt.Errorf("%w: index of table '%s' and column '%s'", err, index.table.name, col.colName)
2✔
1285
                }
2✔
1286

1287
                if n > MaxKeyLen {
12,893✔
1288
                        return nil, fmt.Errorf("%w: invalid key entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1289
                }
×
1290

1291
                indexKeyLen += n
12,893✔
1292

12,893✔
1293
                _, err = valbuf.Write(encVal)
12,893✔
1294
                if err != nil {
12,893✔
1295
                        return nil, err
×
1296
                }
×
1297
        }
1298

1299
        if indexKeyLen > MaxKeyLen {
12,881✔
1300
                return nil, fmt.Errorf("%w: invalid key entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1301
        }
×
1302

1303
        return valbuf.Bytes(), nil
12,881✔
1304
}
1305

1306
func (tx *SQLTx) fetchPKRow(ctx context.Context, table *Table, valuesByColID map[uint32]TypedValue) (*Row, error) {
57✔
1307
        pkRanges := make(map[uint32]*typedValueRange, len(table.primaryIndex.cols))
57✔
1308

57✔
1309
        for _, pkCol := range table.primaryIndex.cols {
114✔
1310
                pkVal := valuesByColID[pkCol.id]
57✔
1311

57✔
1312
                pkRanges[pkCol.id] = &typedValueRange{
57✔
1313
                        lRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1314
                        hRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1315
                }
57✔
1316
        }
57✔
1317

1318
        scanSpecs := &ScanSpecs{
57✔
1319
                Index:         table.primaryIndex,
57✔
1320
                rangesByColID: pkRanges,
57✔
1321
        }
57✔
1322

57✔
1323
        r, err := newRawRowReader(tx, nil, table, period{}, table.name, scanSpecs)
57✔
1324
        if err != nil {
57✔
1325
                return nil, err
×
1326
        }
×
1327

1328
        defer func() {
114✔
1329
                r.Close()
57✔
1330
        }()
57✔
1331

1332
        return r.Read(ctx)
57✔
1333
}
1334

1335
// deprecateIndexEntries mark previous index entries as deleted
1336
func (tx *SQLTx) deprecateIndexEntries(
1337
        pkEncVals []byte,
1338
        currValuesByColID, newValuesByColID map[uint32]TypedValue,
1339
        table *Table) (reusableIndexEntries map[uint32]struct{}, err error) {
36✔
1340

36✔
1341
        encodedRowValue, err := tx.encodeRowValue(currValuesByColID, table)
36✔
1342
        if err != nil {
36✔
1343
                return nil, err
×
1344
        }
×
1345

1346
        reusableIndexEntries = make(map[uint32]struct{})
36✔
1347

36✔
1348
        for _, index := range table.indexes {
149✔
1349
                if index.IsPrimary() {
149✔
1350
                        continue
36✔
1351
                }
1352

1353
                encodedValues := make([][]byte, 2+len(index.cols)+1)
77✔
1354
                encodedValues[0] = EncodeID(table.id)
77✔
1355
                encodedValues[1] = EncodeID(index.id)
77✔
1356
                encodedValues[len(encodedValues)-1] = pkEncVals
77✔
1357

77✔
1358
                // existent index entry is deleted only if it differs from existent one
77✔
1359
                sameIndexKey := true
77✔
1360

77✔
1361
                for i, col := range index.cols {
159✔
1362
                        currVal, specified := currValuesByColID[col.id]
82✔
1363
                        if !specified {
82✔
1364
                                currVal = &NullValue{t: col.colType}
×
1365
                        }
×
1366

1367
                        newVal, specified := newValuesByColID[col.id]
82✔
1368
                        if !specified {
86✔
1369
                                newVal = &NullValue{t: col.colType}
4✔
1370
                        }
4✔
1371

1372
                        r, err := currVal.Compare(newVal)
82✔
1373
                        if err != nil {
82✔
1374
                                return nil, err
×
1375
                        }
×
1376

1377
                        sameIndexKey = sameIndexKey && r == 0
82✔
1378

82✔
1379
                        encVal, _, _ := EncodeValueAsKey(currVal, col.colType, col.MaxLen())
82✔
1380

82✔
1381
                        encodedValues[i+3] = encVal
82✔
1382
                }
1383

1384
                // mark existent index entry as deleted
1385
                if sameIndexKey {
127✔
1386
                        reusableIndexEntries[index.id] = struct{}{}
50✔
1387
                } else {
77✔
1388
                        md := store.NewKVMetadata()
27✔
1389

27✔
1390
                        md.AsDeleted(true)
27✔
1391

27✔
1392
                        err = tx.set(MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...), md, encodedRowValue)
27✔
1393
                        if err != nil {
27✔
1394
                                return nil, err
×
1395
                        }
×
1396
                }
1397
        }
1398

1399
        return reusableIndexEntries, nil
36✔
1400
}
1401

1402
// UpdateStmt is an UPDATE statement. Rows are selected by an internal SELECT
// built from tableRef, where, indexOn, limit and offset. Each selected row then
// receives the SET assignments listed in updates.
type UpdateStmt struct {
	tableRef *tableRef
	where    ValueExp
	updates  []*colUpdate
	// indexOn is forwarded as the internal SELECT's indexOn.
	indexOn []string
	limit   ValueExp
	offset  ValueExp
}
1410

1411
// colUpdate is a single SET assignment: col <op> val. UpdateStmt.validate only
// accepts op == EQ.
type colUpdate struct {
	col string
	op  CmpOperator
	val ValueExp
}
1416

1417
func (stmt *UpdateStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1418
        selectStmt := &SelectStmt{
1✔
1419
                ds:    stmt.tableRef,
1✔
1420
                where: stmt.where,
1✔
1421
        }
1✔
1422

1✔
1423
        err := selectStmt.inferParameters(ctx, tx, params)
1✔
1424
        if err != nil {
1✔
1425
                return err
×
1426
        }
×
1427

1428
        table, err := stmt.tableRef.referencedTable(tx)
1✔
1429
        if err != nil {
1✔
1430
                return err
×
1431
        }
×
1432

1433
        for _, update := range stmt.updates {
2✔
1434
                col, err := table.GetColumnByName(update.col)
1✔
1435
                if err != nil {
1✔
1436
                        return err
×
1437
                }
×
1438

1439
                err = update.val.requiresType(col.colType, make(map[string]ColDescriptor), params, table.name)
1✔
1440
                if err != nil {
1✔
1441
                        return err
×
1442
                }
×
1443
        }
1444

1445
        return nil
1✔
1446
}
1447

1448
func (stmt *UpdateStmt) validate(table *Table) error {
21✔
1449
        colIDs := make(map[uint32]struct{}, len(stmt.updates))
21✔
1450

21✔
1451
        for _, update := range stmt.updates {
44✔
1452
                if update.op != EQ {
23✔
1453
                        return ErrIllegalArguments
×
1454
                }
×
1455

1456
                col, err := table.GetColumnByName(update.col)
23✔
1457
                if err != nil {
24✔
1458
                        return err
1✔
1459
                }
1✔
1460

1461
                if table.PrimaryIndex().IncludesCol(col.id) {
22✔
1462
                        return ErrPKCanNotBeUpdated
×
1463
                }
×
1464

1465
                _, duplicated := colIDs[col.id]
22✔
1466
                if duplicated {
22✔
1467
                        return ErrDuplicatedColumn
×
1468
                }
×
1469

1470
                colIDs[col.id] = struct{}{}
22✔
1471
        }
1472

1473
        return nil
20✔
1474
}
1475

1476
func (stmt *UpdateStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
22✔
1477
        selectStmt := &SelectStmt{
22✔
1478
                ds:      stmt.tableRef,
22✔
1479
                where:   stmt.where,
22✔
1480
                indexOn: stmt.indexOn,
22✔
1481
                limit:   stmt.limit,
22✔
1482
                offset:  stmt.offset,
22✔
1483
        }
22✔
1484

22✔
1485
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
22✔
1486
        if err != nil {
23✔
1487
                return nil, err
1✔
1488
        }
1✔
1489
        defer rowReader.Close()
21✔
1490

21✔
1491
        table := rowReader.ScanSpecs().Index.table
21✔
1492

21✔
1493
        err = stmt.validate(table)
21✔
1494
        if err != nil {
22✔
1495
                return nil, err
1✔
1496
        }
1✔
1497

1498
        cols, err := rowReader.colsBySelector(ctx)
20✔
1499
        if err != nil {
20✔
1500
                return nil, err
×
1501
        }
×
1502

1503
        for {
72✔
1504
                row, err := rowReader.Read(ctx)
52✔
1505
                if errors.Is(err, ErrNoMoreRows) {
69✔
1506
                        break
17✔
1507
                } else if err != nil {
36✔
1508
                        return nil, err
1✔
1509
                }
1✔
1510

1511
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
34✔
1512

34✔
1513
                for _, col := range table.cols {
127✔
1514
                        encSel := EncodeSelector("", table.name, col.colName)
93✔
1515
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
93✔
1516
                }
93✔
1517

1518
                for _, update := range stmt.updates {
70✔
1519
                        col, err := table.GetColumnByName(update.col)
36✔
1520
                        if err != nil {
36✔
1521
                                return nil, err
×
1522
                        }
×
1523

1524
                        sval, err := update.val.substitute(params)
36✔
1525
                        if err != nil {
36✔
1526
                                return nil, err
×
1527
                        }
×
1528

1529
                        rval, err := sval.reduce(tx, row, table.name)
36✔
1530
                        if err != nil {
36✔
1531
                                return nil, err
×
1532
                        }
×
1533

1534
                        err = rval.requiresType(col.colType, cols, nil, table.name)
36✔
1535
                        if err != nil {
36✔
1536
                                return nil, err
×
1537
                        }
×
1538

1539
                        valuesByColID[col.id] = rval
36✔
1540
                }
1541

1542
                for i, col := range table.cols {
127✔
1543
                        v := valuesByColID[col.id]
93✔
1544

93✔
1545
                        row.ValuesByPosition[i] = v
93✔
1546
                        row.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
93✔
1547
                }
93✔
1548

1549
                if err := checkConstraints(tx, table.checkConstraints, row, table.name); err != nil {
36✔
1550
                        return nil, err
2✔
1551
                }
2✔
1552

1553
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
32✔
1554
                if err != nil {
32✔
1555
                        return nil, err
×
1556
                }
×
1557

1558
                // primary index entry
1559
                mkey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
32✔
1560

32✔
1561
                // mkey must exist
32✔
1562
                _, err = tx.get(ctx, mkey)
32✔
1563
                if err != nil {
32✔
1564
                        return nil, err
×
1565
                }
×
1566

1567
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, true)
32✔
1568
                if err != nil {
32✔
1569
                        return nil, err
×
1570
                }
×
1571
        }
1572

1573
        return tx, nil
17✔
1574
}
1575

1576
// DeleteFromStmt is a DELETE statement. The rows to delete are selected by an
// internal SELECT built from all of its fields.
type DeleteFromStmt struct {
	tableRef *tableRef
	where    ValueExp
	// indexOn is forwarded as the internal SELECT's indexOn.
	indexOn []string
	orderBy []*OrdCol
	limit   ValueExp
	offset  ValueExp
}
1584

1585
func NewDeleteFromStmt(table string, where ValueExp, orderBy []*OrdCol, limit ValueExp) *DeleteFromStmt {
4✔
1586
        return &DeleteFromStmt{
4✔
1587
                tableRef: NewTableRef(table, ""),
4✔
1588
                where:    where,
4✔
1589
                orderBy:  orderBy,
4✔
1590
                limit:    limit,
4✔
1591
        }
4✔
1592
}
4✔
1593

1594
func (stmt *DeleteFromStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1595
        selectStmt := &SelectStmt{
1✔
1596
                ds:      stmt.tableRef,
1✔
1597
                where:   stmt.where,
1✔
1598
                orderBy: stmt.orderBy,
1✔
1599
        }
1✔
1600
        return selectStmt.inferParameters(ctx, tx, params)
1✔
1601
}
1✔
1602

1603
func (stmt *DeleteFromStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
1604
        selectStmt := &SelectStmt{
15✔
1605
                ds:      stmt.tableRef,
15✔
1606
                where:   stmt.where,
15✔
1607
                indexOn: stmt.indexOn,
15✔
1608
                orderBy: stmt.orderBy,
15✔
1609
                limit:   stmt.limit,
15✔
1610
                offset:  stmt.offset,
15✔
1611
        }
15✔
1612

15✔
1613
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
15✔
1614
        if err != nil {
17✔
1615
                return nil, err
2✔
1616
        }
2✔
1617
        defer rowReader.Close()
13✔
1618

13✔
1619
        table := rowReader.ScanSpecs().Index.table
13✔
1620

13✔
1621
        for {
147✔
1622
                row, err := rowReader.Read(ctx)
134✔
1623
                if errors.Is(err, ErrNoMoreRows) {
146✔
1624
                        break
12✔
1625
                }
1626
                if err != nil {
123✔
1627
                        return nil, err
1✔
1628
                }
1✔
1629

1630
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
121✔
1631

121✔
1632
                for _, col := range table.cols {
406✔
1633
                        encSel := EncodeSelector("", table.name, col.colName)
285✔
1634
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
285✔
1635
                }
285✔
1636

1637
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
121✔
1638
                if err != nil {
121✔
1639
                        return nil, err
×
1640
                }
×
1641

1642
                err = tx.deleteIndexEntries(pkEncVals, valuesByColID, table)
121✔
1643
                if err != nil {
121✔
1644
                        return nil, err
×
1645
                }
×
1646

1647
                tx.updatedRows++
121✔
1648
        }
1649

1650
        return tx, nil
12✔
1651
}
1652

1653
func (tx *SQLTx) deleteIndexEntries(pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table) error {
121✔
1654
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
121✔
1655
        if err != nil {
121✔
1656
                return err
×
1657
        }
×
1658

1659
        for _, index := range table.indexes {
291✔
1660
                if !index.IsPrimary() {
219✔
1661
                        continue
49✔
1662
                }
1663

1664
                encodedValues := make([][]byte, 3+len(index.cols))
121✔
1665
                encodedValues[0] = EncodeID(DatabaseID)
121✔
1666
                encodedValues[1] = EncodeID(table.id)
121✔
1667
                encodedValues[2] = EncodeID(index.id)
121✔
1668

121✔
1669
                for i, col := range index.cols {
242✔
1670
                        val, specified := valuesByColID[col.id]
121✔
1671
                        if !specified {
121✔
1672
                                val = &NullValue{t: col.colType}
×
1673
                        }
×
1674

1675
                        encVal, _, _ := EncodeValueAsKey(val, col.colType, col.MaxLen())
121✔
1676

121✔
1677
                        encodedValues[i+3] = encVal
121✔
1678
                }
1679

1680
                md := store.NewKVMetadata()
121✔
1681

121✔
1682
                md.AsDeleted(true)
121✔
1683

121✔
1684
                err := tx.set(MapKey(tx.sqlPrefix(), RowPrefix, encodedValues...), md, encodedRowValue)
121✔
1685
                if err != nil {
121✔
1686
                        return err
×
1687
                }
×
1688
        }
1689

1690
        return nil
121✔
1691
}
1692

1693
type ValueExp interface {
1694
        inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
1695
        requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
1696
        substitute(params map[string]interface{}) (ValueExp, error)
1697
        reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error)
1698
        reduceSelectors(row *Row, implicitTable string) ValueExp
1699
        isConstant() bool
1700
        selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error
1701
        String() string
1702
}
1703

1704
type typedValueRange struct {
1705
        lRange *typedValueSemiRange
1706
        hRange *typedValueSemiRange
1707
}
1708

1709
type typedValueSemiRange struct {
1710
        val       TypedValue
1711
        inclusive bool
1712
}
1713

1714
func (r *typedValueRange) unitary() bool {
19✔
1715
        // TODO: this simplified implementation doesn't cover all unitary cases e.g. 3<=v<4
19✔
1716
        if r.lRange == nil || r.hRange == nil {
19✔
1717
                return false
×
1718
        }
×
1719

1720
        res, _ := r.lRange.val.Compare(r.hRange.val)
19✔
1721
        return res == 0 && r.lRange.inclusive && r.hRange.inclusive
19✔
1722
}
1723

1724
func (r *typedValueRange) refineWith(refiningRange *typedValueRange) error {
3✔
1725
        if r.lRange == nil {
4✔
1726
                r.lRange = refiningRange.lRange
1✔
1727
        } else if r.lRange != nil && refiningRange.lRange != nil {
4✔
1728
                maxRange, err := maxSemiRange(r.lRange, refiningRange.lRange)
1✔
1729
                if err != nil {
1✔
1730
                        return err
×
1731
                }
×
1732
                r.lRange = maxRange
1✔
1733
        }
1734

1735
        if r.hRange == nil {
4✔
1736
                r.hRange = refiningRange.hRange
1✔
1737
        } else if r.hRange != nil && refiningRange.hRange != nil {
5✔
1738
                minRange, err := minSemiRange(r.hRange, refiningRange.hRange)
2✔
1739
                if err != nil {
2✔
1740
                        return err
×
1741
                }
×
1742
                r.hRange = minRange
2✔
1743
        }
1744

1745
        return nil
3✔
1746
}
1747

1748
func (r *typedValueRange) extendWith(extendingRange *typedValueRange) error {
5✔
1749
        if r.lRange == nil || extendingRange.lRange == nil {
7✔
1750
                r.lRange = nil
2✔
1751
        } else {
5✔
1752
                minRange, err := minSemiRange(r.lRange, extendingRange.lRange)
3✔
1753
                if err != nil {
3✔
1754
                        return err
×
1755
                }
×
1756
                r.lRange = minRange
3✔
1757
        }
1758

1759
        if r.hRange == nil || extendingRange.hRange == nil {
8✔
1760
                r.hRange = nil
3✔
1761
        } else {
5✔
1762
                maxRange, err := maxSemiRange(r.hRange, extendingRange.hRange)
2✔
1763
                if err != nil {
2✔
1764
                        return err
×
1765
                }
×
1766
                r.hRange = maxRange
2✔
1767
        }
1768

1769
        return nil
5✔
1770
}
1771

1772
func maxSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
3✔
1773
        r, err := or1.val.Compare(or2.val)
3✔
1774
        if err != nil {
3✔
1775
                return nil, err
×
1776
        }
×
1777

1778
        maxVal := or1.val
3✔
1779
        if r < 0 {
5✔
1780
                maxVal = or2.val
2✔
1781
        }
2✔
1782

1783
        return &typedValueSemiRange{
3✔
1784
                val:       maxVal,
3✔
1785
                inclusive: or1.inclusive && or2.inclusive,
3✔
1786
        }, nil
3✔
1787
}
1788

1789
func minSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
5✔
1790
        r, err := or1.val.Compare(or2.val)
5✔
1791
        if err != nil {
5✔
1792
                return nil, err
×
1793
        }
×
1794

1795
        minVal := or1.val
5✔
1796
        if r > 0 {
9✔
1797
                minVal = or2.val
4✔
1798
        }
4✔
1799

1800
        return &typedValueSemiRange{
5✔
1801
                val:       minVal,
5✔
1802
                inclusive: or1.inclusive || or2.inclusive,
5✔
1803
        }, nil
5✔
1804
}
1805

1806
type TypedValue interface {
1807
        ValueExp
1808
        Type() SQLValueType
1809
        RawValue() interface{}
1810
        Compare(val TypedValue) (int, error)
1811
        IsNull() bool
1812
}
1813

1814
type Tuple []TypedValue
1815

1816
func (t Tuple) Compare(other Tuple) (int, int, error) {
85,700✔
1817
        if len(t) != len(other) {
85,700✔
1818
                return -1, -1, ErrNotComparableValues
×
1819
        }
×
1820

1821
        for i := range t {
187,354✔
1822
                res, err := t[i].Compare(other[i])
101,654✔
1823
                if err != nil || res != 0 {
179,729✔
1824
                        return res, i, err
78,075✔
1825
                }
78,075✔
1826
        }
1827
        return 0, -1, nil
7,625✔
1828
}
1829

1830
func NewNull(t SQLValueType) *NullValue {
259✔
1831
        return &NullValue{t: t}
259✔
1832
}
259✔
1833

1834
type NullValue struct {
1835
        t SQLValueType
1836
}
1837

1838
func (n *NullValue) Type() SQLValueType {
62✔
1839
        return n.t
62✔
1840
}
62✔
1841

1842
func (n *NullValue) RawValue() interface{} {
247✔
1843
        return nil
247✔
1844
}
247✔
1845

1846
func (n *NullValue) IsNull() bool {
326✔
1847
        return true
326✔
1848
}
326✔
1849

1850
func (n *NullValue) String() string {
4✔
1851
        return "NULL"
4✔
1852
}
4✔
1853

1854
func (n *NullValue) Compare(val TypedValue) (int, error) {
76✔
1855
        if n.t != AnyType && val.Type() != AnyType && n.t != val.Type() {
77✔
1856
                return 0, ErrNotComparableValues
1✔
1857
        }
1✔
1858

1859
        if val.RawValue() == nil {
107✔
1860
                return 0, nil
32✔
1861
        }
32✔
1862
        return -1, nil
43✔
1863
}
1864

1865
func (v *NullValue) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
5✔
1866
        return v.t, nil
5✔
1867
}
5✔
1868

1869
func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
9✔
1870
        if v.t == t {
14✔
1871
                return nil
5✔
1872
        }
5✔
1873

1874
        if v.t != AnyType {
5✔
1875
                return ErrInvalidTypes
1✔
1876
        }
1✔
1877

1878
        v.t = t
3✔
1879

3✔
1880
        return nil
3✔
1881
}
1882

1883
func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) {
280✔
1884
        return v, nil
280✔
1885
}
280✔
1886

1887
func (v *NullValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
221✔
1888
        return v, nil
221✔
1889
}
221✔
1890

1891
func (v *NullValue) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
1892
        return v
10✔
1893
}
10✔
1894

1895
func (v *NullValue) isConstant() bool {
12✔
1896
        return true
12✔
1897
}
12✔
1898

1899
func (v *NullValue) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
1900
        return nil
1✔
1901
}
1✔
1902

1903
type Integer struct {
1904
        val int64
1905
}
1906

1907
func NewInteger(val int64) *Integer {
293✔
1908
        return &Integer{val: val}
293✔
1909
}
293✔
1910

1911
func (v *Integer) Type() SQLValueType {
164,344✔
1912
        return IntegerType
164,344✔
1913
}
164,344✔
1914

1915
func (v *Integer) IsNull() bool {
72,279✔
1916
        return false
72,279✔
1917
}
72,279✔
1918

1919
func (v *Integer) String() string {
31✔
1920
        return strconv.FormatInt(v.val, 10)
31✔
1921
}
31✔
1922

1923
func (v *Integer) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
19✔
1924
        return IntegerType, nil
19✔
1925
}
19✔
1926

1927
func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
23✔
1928
        if t != IntegerType && t != JSONType {
27✔
1929
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
4✔
1930
        }
4✔
1931

1932
        return nil
19✔
1933
}
1934

1935
func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) {
2,172✔
1936
        return v, nil
2,172✔
1937
}
2,172✔
1938

1939
func (v *Integer) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
3,918✔
1940
        return v, nil
3,918✔
1941
}
3,918✔
1942

1943
func (v *Integer) reduceSelectors(row *Row, implicitTable string) ValueExp {
6✔
1944
        return v
6✔
1945
}
6✔
1946

1947
func (v *Integer) isConstant() bool {
116✔
1948
        return true
116✔
1949
}
116✔
1950

1951
func (v *Integer) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
1952
        return nil
1✔
1953
}
1✔
1954

1955
func (v *Integer) RawValue() interface{} {
110,798✔
1956
        return v.val
110,798✔
1957
}
110,798✔
1958

1959
func (v *Integer) Compare(val TypedValue) (int, error) {
51,023✔
1960
        if val.IsNull() {
51,066✔
1961
                return 1, nil
43✔
1962
        }
43✔
1963

1964
        if val.Type() == JSONType {
50,981✔
1965
                res, err := val.Compare(v)
1✔
1966
                return -res, err
1✔
1967
        }
1✔
1968

1969
        if val.Type() == Float64Type {
50,979✔
1970
                r, err := val.Compare(v)
×
1971
                return r * -1, err
×
1972
        }
×
1973

1974
        if val.Type() != IntegerType {
50,986✔
1975
                return 0, ErrNotComparableValues
7✔
1976
        }
7✔
1977

1978
        rval := val.RawValue().(int64)
50,972✔
1979

50,972✔
1980
        if v.val == rval {
63,345✔
1981
                return 0, nil
12,373✔
1982
        }
12,373✔
1983

1984
        if v.val > rval {
57,712✔
1985
                return 1, nil
19,113✔
1986
        }
19,113✔
1987

1988
        return -1, nil
19,486✔
1989
}
1990

1991
type Timestamp struct {
1992
        val time.Time
1993
}
1994

1995
func (v *Timestamp) Type() SQLValueType {
19,456✔
1996
        return TimestampType
19,456✔
1997
}
19,456✔
1998

1999
func (v *Timestamp) IsNull() bool {
17,763✔
2000
        return false
17,763✔
2001
}
17,763✔
2002

2003
func (v *Timestamp) String() string {
1✔
2004
        return v.val.Format("2006-01-02 15:04:05.999999")
1✔
2005
}
1✔
2006

2007
func (v *Timestamp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2008
        return TimestampType, nil
1✔
2009
}
1✔
2010

2011
func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
2012
        if t != TimestampType {
3✔
2013
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2014
        }
1✔
2015

2016
        return nil
1✔
2017
}
2018

2019
func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) {
120✔
2020
        return v, nil
120✔
2021
}
120✔
2022

2023
func (v *Timestamp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
969✔
2024
        return v, nil
969✔
2025
}
969✔
2026

2027
func (v *Timestamp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2028
        return v
1✔
2029
}
1✔
2030

2031
func (v *Timestamp) isConstant() bool {
1✔
2032
        return true
1✔
2033
}
1✔
2034

2035
func (v *Timestamp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2036
        return nil
1✔
2037
}
1✔
2038

2039
func (v *Timestamp) RawValue() interface{} {
31,773✔
2040
        return v.val
31,773✔
2041
}
31,773✔
2042

2043
func (v *Timestamp) Compare(val TypedValue) (int, error) {
14,979✔
2044
        if val.IsNull() {
14,981✔
2045
                return 1, nil
2✔
2046
        }
2✔
2047

2048
        if val.Type() != TimestampType {
14,978✔
2049
                return 0, ErrNotComparableValues
1✔
2050
        }
1✔
2051

2052
        rval := val.RawValue().(time.Time)
14,976✔
2053

14,976✔
2054
        if v.val.Before(rval) {
22,214✔
2055
                return -1, nil
7,238✔
2056
        }
7,238✔
2057

2058
        if v.val.After(rval) {
15,397✔
2059
                return 1, nil
7,659✔
2060
        }
7,659✔
2061

2062
        return 0, nil
79✔
2063
}
2064

2065
type Varchar struct {
2066
        val string
2067
}
2068

2069
func NewVarchar(val string) *Varchar {
1,543✔
2070
        return &Varchar{val: val}
1,543✔
2071
}
1,543✔
2072

2073
func (v *Varchar) Type() SQLValueType {
84,357✔
2074
        return VarcharType
84,357✔
2075
}
84,357✔
2076

2077
func (v *Varchar) IsNull() bool {
44,484✔
2078
        return false
44,484✔
2079
}
44,484✔
2080

2081
func (v *Varchar) String() string {
8✔
2082
        return fmt.Sprintf("'%s'", v.val)
8✔
2083
}
8✔
2084

2085
func (v *Varchar) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
16✔
2086
        return VarcharType, nil
16✔
2087
}
16✔
2088

2089
func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
104✔
2090
        if t != VarcharType && t != JSONType {
106✔
2091
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
2✔
2092
        }
2✔
2093
        return nil
102✔
2094
}
2095

2096
func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) {
1,571✔
2097
        return v, nil
1,571✔
2098
}
1,571✔
2099

2100
func (v *Varchar) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2,522✔
2101
        return v, nil
2,522✔
2102
}
2,522✔
2103

2104
func (v *Varchar) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2105
        return v
×
2106
}
×
2107

2108
func (v *Varchar) isConstant() bool {
38✔
2109
        return true
38✔
2110
}
38✔
2111

2112
func (v *Varchar) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2113
        return nil
1✔
2114
}
1✔
2115

2116
func (v *Varchar) RawValue() interface{} {
59,378✔
2117
        return v.val
59,378✔
2118
}
59,378✔
2119

2120
func (v *Varchar) Compare(val TypedValue) (int, error) {
39,678✔
2121
        if val.IsNull() {
39,708✔
2122
                return 1, nil
30✔
2123
        }
30✔
2124

2125
        if val.Type() == JSONType {
40,649✔
2126
                res, err := val.Compare(v)
1,001✔
2127
                return -res, err
1,001✔
2128
        }
1,001✔
2129

2130
        if val.Type() != VarcharType {
38,648✔
2131
                return 0, ErrNotComparableValues
1✔
2132
        }
1✔
2133

2134
        rval := val.RawValue().(string)
38,646✔
2135

38,646✔
2136
        return bytes.Compare([]byte(v.val), []byte(rval)), nil
38,646✔
2137
}
2138

2139
type UUID struct {
2140
        val uuid.UUID
2141
}
2142

2143
func NewUUID(val uuid.UUID) *UUID {
1✔
2144
        return &UUID{val: val}
1✔
2145
}
1✔
2146

2147
func (v *UUID) Type() SQLValueType {
10✔
2148
        return UUIDType
10✔
2149
}
10✔
2150

2151
func (v *UUID) IsNull() bool {
26✔
2152
        return false
26✔
2153
}
26✔
2154

2155
func (v *UUID) String() string {
1✔
2156
        return v.val.String()
1✔
2157
}
1✔
2158

2159
func (v *UUID) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2160
        return UUIDType, nil
1✔
2161
}
1✔
2162

2163
func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
2164
        if t != UUIDType {
6✔
2165
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
2✔
2166
        }
2✔
2167

2168
        return nil
2✔
2169
}
2170

2171
func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) {
2✔
2172
        return v, nil
2✔
2173
}
2✔
2174

2175
func (v *UUID) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
1✔
2176
        return v, nil
1✔
2177
}
1✔
2178

2179
func (v *UUID) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2180
        return v
1✔
2181
}
1✔
2182

2183
func (v *UUID) isConstant() bool {
1✔
2184
        return true
1✔
2185
}
1✔
2186

2187
func (v *UUID) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2188
        return nil
1✔
2189
}
1✔
2190

2191
func (v *UUID) RawValue() interface{} {
40✔
2192
        return v.val
40✔
2193
}
40✔
2194

2195
func (v *UUID) Compare(val TypedValue) (int, error) {
5✔
2196
        if val.IsNull() {
7✔
2197
                return 1, nil
2✔
2198
        }
2✔
2199

2200
        if val.Type() != UUIDType {
4✔
2201
                return 0, ErrNotComparableValues
1✔
2202
        }
1✔
2203

2204
        rval := val.RawValue().(uuid.UUID)
2✔
2205

2✔
2206
        return bytes.Compare(v.val[:], rval[:]), nil
2✔
2207
}
2208

2209
type Bool struct {
2210
        val bool
2211
}
2212

2213
func NewBool(val bool) *Bool {
106✔
2214
        return &Bool{val: val}
106✔
2215
}
106✔
2216

2217
func (v *Bool) Type() SQLValueType {
610✔
2218
        return BooleanType
610✔
2219
}
610✔
2220

2221
func (v *Bool) IsNull() bool {
1,018✔
2222
        return false
1,018✔
2223
}
1,018✔
2224

2225
func (v *Bool) String() string {
5✔
2226
        return strconv.FormatBool(v.val)
5✔
2227
}
5✔
2228

2229
func (v *Bool) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
24✔
2230
        return BooleanType, nil
24✔
2231
}
24✔
2232

2233
func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
49✔
2234
        if t != BooleanType && t != JSONType {
54✔
2235
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
5✔
2236
        }
5✔
2237
        return nil
44✔
2238
}
2239

2240
func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) {
394✔
2241
        return v, nil
394✔
2242
}
394✔
2243

2244
func (v *Bool) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
490✔
2245
        return v, nil
490✔
2246
}
490✔
2247

2248
func (v *Bool) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2249
        return v
×
2250
}
×
2251

2252
func (v *Bool) isConstant() bool {
3✔
2253
        return true
3✔
2254
}
3✔
2255

2256
func (v *Bool) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
2257
        return nil
7✔
2258
}
7✔
2259

2260
func (v *Bool) RawValue() interface{} {
939✔
2261
        return v.val
939✔
2262
}
939✔
2263

2264
func (v *Bool) Compare(val TypedValue) (int, error) {
212✔
2265
        if val.IsNull() {
242✔
2266
                return 1, nil
30✔
2267
        }
30✔
2268

2269
        if val.Type() == JSONType {
183✔
2270
                res, err := val.Compare(v)
1✔
2271
                return -res, err
1✔
2272
        }
1✔
2273

2274
        if val.Type() != BooleanType {
181✔
2275
                return 0, ErrNotComparableValues
×
2276
        }
×
2277

2278
        rval := val.RawValue().(bool)
181✔
2279

181✔
2280
        if v.val == rval {
287✔
2281
                return 0, nil
106✔
2282
        }
106✔
2283

2284
        if v.val {
81✔
2285
                return 1, nil
6✔
2286
        }
6✔
2287

2288
        return -1, nil
69✔
2289
}
2290

2291
type Blob struct {
2292
        val []byte
2293
}
2294

2295
func NewBlob(val []byte) *Blob {
286✔
2296
        return &Blob{val: val}
286✔
2297
}
286✔
2298

2299
func (v *Blob) Type() SQLValueType {
53✔
2300
        return BLOBType
53✔
2301
}
53✔
2302

2303
func (v *Blob) IsNull() bool {
2,306✔
2304
        return false
2,306✔
2305
}
2,306✔
2306

2307
func (v *Blob) String() string {
2✔
2308
        return hex.EncodeToString(v.val)
2✔
2309
}
2✔
2310

2311
func (v *Blob) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2312
        return BLOBType, nil
1✔
2313
}
1✔
2314

2315
func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
2316
        if t != BLOBType {
3✔
2317
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BLOBType, t)
1✔
2318
        }
1✔
2319

2320
        return nil
1✔
2321
}
2322

2323
func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) {
358✔
2324
        return v, nil
358✔
2325
}
358✔
2326

2327
func (v *Blob) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
370✔
2328
        return v, nil
370✔
2329
}
370✔
2330

2331
func (v *Blob) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2332
        return v
×
2333
}
×
2334

2335
func (v *Blob) isConstant() bool {
7✔
2336
        return true
7✔
2337
}
7✔
2338

2339
func (v *Blob) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2340
        return nil
×
2341
}
×
2342

2343
func (v *Blob) RawValue() interface{} {
2,561✔
2344
        return v.val
2,561✔
2345
}
2,561✔
2346

2347
func (v *Blob) Compare(val TypedValue) (int, error) {
25✔
2348
        if val.IsNull() {
25✔
2349
                return 1, nil
×
2350
        }
×
2351

2352
        if val.Type() != BLOBType {
25✔
2353
                return 0, ErrNotComparableValues
×
2354
        }
×
2355

2356
        rval := val.RawValue().([]byte)
25✔
2357

25✔
2358
        return bytes.Compare(v.val, rval), nil
25✔
2359
}
2360

2361
type Float64 struct {
2362
        val float64
2363
}
2364

2365
func NewFloat64(val float64) *Float64 {
1,149✔
2366
        return &Float64{val: val}
1,149✔
2367
}
1,149✔
2368

2369
func (v *Float64) Type() SQLValueType {
5,823✔
2370
        return Float64Type
5,823✔
2371
}
5,823✔
2372

2373
func (v *Float64) IsNull() bool {
3,306✔
2374
        return false
3,306✔
2375
}
3,306✔
2376

2377
func (v *Float64) String() string {
3✔
2378
        return strconv.FormatFloat(float64(v.val), 'f', -1, 64)
3✔
2379
}
3✔
2380

2381
func (v *Float64) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2382
        return Float64Type, nil
1✔
2383
}
1✔
2384

2385
func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
2386
        if t != Float64Type && t != JSONType {
8✔
2387
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, Float64Type, t)
1✔
2388
        }
1✔
2389
        return nil
6✔
2390
}
2391

2392
func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) {
92✔
2393
        return v, nil
92✔
2394
}
92✔
2395

2396
func (v *Float64) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
955✔
2397
        return v, nil
955✔
2398
}
955✔
2399

2400
func (v *Float64) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2401
        return v
1✔
2402
}
1✔
2403

2404
func (v *Float64) isConstant() bool {
5✔
2405
        return true
5✔
2406
}
5✔
2407

2408
func (v *Float64) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2409
        return nil
1✔
2410
}
1✔
2411

2412
func (v *Float64) RawValue() interface{} {
18,775✔
2413
        return v.val
18,775✔
2414
}
18,775✔
2415

2416
func (v *Float64) Compare(val TypedValue) (int, error) {
806✔
2417
        if val.Type() == JSONType {
807✔
2418
                res, err := val.Compare(v)
1✔
2419
                return -res, err
1✔
2420
        }
1✔
2421

2422
        convVal, err := mayApplyImplicitConversion(val.RawValue(), Float64Type)
805✔
2423
        if err != nil {
806✔
2424
                return 0, err
1✔
2425
        }
1✔
2426

2427
        if convVal == nil {
807✔
2428
                return 1, nil
3✔
2429
        }
3✔
2430

2431
        rval, ok := convVal.(float64)
801✔
2432
        if !ok {
801✔
2433
                return 0, ErrNotComparableValues
×
2434
        }
×
2435

2436
        if v.val == rval {
827✔
2437
                return 0, nil
26✔
2438
        }
26✔
2439

2440
        if v.val > rval {
1,129✔
2441
                return 1, nil
354✔
2442
        }
354✔
2443

2444
        return -1, nil
421✔
2445
}
2446

2447
type FnCall struct {
2448
        fn     string
2449
        params []ValueExp
2450
}
2451

2452
func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
15✔
2453
        if strings.ToUpper(v.fn) == NowFnCall {
29✔
2454
                return TimestampType, nil
14✔
2455
        }
14✔
2456

2457
        if strings.ToUpper(v.fn) == UUIDFnCall {
1✔
2458
                return UUIDType, nil
×
2459
        }
×
2460

2461
        if strings.ToUpper(v.fn) == JSONTypeOfFnCall {
1✔
2462
                return VarcharType, nil
×
2463
        }
×
2464

2465
        return AnyType, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn)
1✔
2466
}
2467

2468
func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
2469
        if strings.ToUpper(v.fn) == NowFnCall {
6✔
2470
                if t != TimestampType {
3✔
2471
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2472
                }
1✔
2473

2474
                return nil
1✔
2475
        }
2476

2477
        if strings.ToUpper(v.fn) == UUIDFnCall {
2✔
2478
                if t != UUIDType {
×
2479
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
×
2480
                }
×
2481

2482
                return nil
×
2483
        }
2484

2485
        if strings.ToUpper(v.fn) == JSONTypeOfFnCall {
2✔
2486
                if t != VarcharType {
×
2487
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
×
2488
                }
×
2489
                return nil
×
2490
        }
2491

2492
        return fmt.Errorf("%w: unkown function %s", ErrIllegalArguments, v.fn)
2✔
2493
}
2494

2495
func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) {
399✔
2496
        ps := make([]ValueExp, len(v.params))
399✔
2497

399✔
2498
        for i, p := range v.params {
699✔
2499
                ps[i], err = p.substitute(params)
300✔
2500
                if err != nil {
300✔
2501
                        return nil, err
×
2502
                }
×
2503
        }
2504

2505
        return &FnCall{
399✔
2506
                fn:     v.fn,
399✔
2507
                params: ps,
399✔
2508
        }, nil
399✔
2509
}
2510

2511
func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
399✔
2512
        if strings.ToUpper(v.fn) == NowFnCall {
494✔
2513
                if len(v.params) > 0 {
95✔
2514
                        return nil, fmt.Errorf("%w: '%s' function does not expect any argument but %d were provided", ErrIllegalArguments, NowFnCall, len(v.params))
×
2515
                }
×
2516
                return &Timestamp{val: tx.Timestamp().Truncate(time.Microsecond).UTC()}, nil
95✔
2517
        }
2518

2519
        if strings.ToUpper(v.fn) == UUIDFnCall {
307✔
2520
                if len(v.params) > 0 {
3✔
2521
                        return nil, fmt.Errorf("%w: '%s' function does not expect any argument but %d were provided", ErrIllegalArguments, UUIDFnCall, len(v.params))
×
2522
                }
×
2523
                return &UUID{val: uuid.New()}, nil
3✔
2524
        }
2525

2526
        if strings.ToUpper(v.fn) == JSONTypeOfFnCall {
601✔
2527
                if len(v.params) != 1 {
300✔
2528
                        return nil, fmt.Errorf("%w: '%s' function expects %d arguments but %d were provided", ErrIllegalArguments, JSONTypeOfFnCall, 1, len(v.params))
×
2529
                }
×
2530

2531
                v, err := v.params[0].reduce(tx, row, implicitTable)
300✔
2532
                if err != nil {
300✔
2533
                        return nil, err
×
2534
                }
×
2535

2536
                if v.IsNull() {
300✔
2537
                        return NewNull(AnyType), nil
×
2538
                }
×
2539

2540
                jsonVal, ok := v.(*JSON)
300✔
2541
                if !ok {
300✔
2542
                        return nil, fmt.Errorf("%w: '%s' function expects an argument of type JSON", ErrIllegalArguments, JSONTypeOfFnCall)
×
2543
                }
×
2544
                return NewVarchar(jsonVal.primitiveType()), nil
300✔
2545
        }
2546
        return nil, fmt.Errorf("%w: unkown function %s", ErrIllegalArguments, v.fn)
1✔
2547
}
2548

2549
func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2550
        return v
×
2551
}
×
2552

2553
func (v *FnCall) isConstant() bool {
13✔
2554
        return false
13✔
2555
}
13✔
2556

2557
func (v *FnCall) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2558
        return nil
×
2559
}
×
2560

2561
func (v *FnCall) String() string {
1✔
2562
        params := make([]string, len(v.params))
1✔
2563
        for i, p := range v.params {
4✔
2564
                params[i] = p.String()
3✔
2565
        }
3✔
2566
        return v.fn + "(" + strings.Join(params, ",") + ")"
1✔
2567
}
2568

2569
type Cast struct {
2570
        val ValueExp
2571
        t   SQLValueType
2572
}
2573

2574
func (c *Cast) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
11✔
2575
        _, err := c.val.inferType(cols, params, implicitTable)
11✔
2576
        if err != nil {
12✔
2577
                return AnyType, err
1✔
2578
        }
1✔
2579

2580
        // val type may be restricted by compatible conversions, but multiple types may be compatible...
2581

2582
        return c.t, nil
10✔
2583
}
2584

2585
func (c *Cast) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
2586
        if c.t != t {
×
2587
                return fmt.Errorf("%w: can not use value cast to %s as %s", ErrInvalidTypes, c.t, t)
×
2588
        }
×
2589

2590
        return nil
×
2591
}
2592

2593
func (c *Cast) substitute(params map[string]interface{}) (ValueExp, error) {
48✔
2594
        val, err := c.val.substitute(params)
48✔
2595
        if err != nil {
48✔
2596
                return nil, err
×
2597
        }
×
2598
        c.val = val
48✔
2599
        return c, nil
48✔
2600
}
2601

2602
func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
48✔
2603
        val, err := c.val.reduce(tx, row, implicitTable)
48✔
2604
        if err != nil {
48✔
2605
                return nil, err
×
2606
        }
×
2607

2608
        conv, err := getConverter(val.Type(), c.t)
48✔
2609
        if conv == nil {
51✔
2610
                return nil, err
3✔
2611
        }
3✔
2612

2613
        return conv(val)
45✔
2614
}
2615

2616
func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2617
        return &Cast{
×
2618
                val: c.val.reduceSelectors(row, implicitTable),
×
2619
                t:   c.t,
×
2620
        }
×
2621
}
×
2622

2623
func (c *Cast) isConstant() bool {
7✔
2624
        return c.val.isConstant()
7✔
2625
}
7✔
2626

2627
func (c *Cast) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2628
        return nil
×
2629
}
×
2630

2631
func (c *Cast) String() string {
1✔
2632
        return fmt.Sprintf("CAST (%s AS %s)", c.val.String(), c.t)
1✔
2633
}
1✔
2634

2635
// Param is a named query parameter, written as @id in SQL text, whose value is
// supplied at execution time.
type Param struct {
	// id is the parameter name, without the leading '@'.
	id string
	// pos is not read in this part of the file; presumably the parameter's
	// position — TODO confirm.
	pos int
}
2639

2640
func (v *Param) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
59✔
2641
        t, ok := params[v.id]
59✔
2642
        if !ok {
116✔
2643
                params[v.id] = AnyType
57✔
2644
                return AnyType, nil
57✔
2645
        }
57✔
2646

2647
        return t, nil
2✔
2648
}
2649

2650
func (v *Param) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
76✔
2651
        currT, ok := params[v.id]
76✔
2652
        if ok && currT != t && currT != AnyType {
80✔
2653
                return ErrInferredMultipleTypes
4✔
2654
        }
4✔
2655

2656
        params[v.id] = t
72✔
2657

72✔
2658
        return nil
72✔
2659
}
2660

2661
func (p *Param) substitute(params map[string]interface{}) (ValueExp, error) {
4,769✔
2662
        val, ok := params[p.id]
4,769✔
2663
        if !ok {
4,831✔
2664
                return nil, fmt.Errorf("%w(%s)", ErrMissingParameter, p.id)
62✔
2665
        }
62✔
2666

2667
        if val == nil {
4,740✔
2668
                return &NullValue{t: AnyType}, nil
33✔
2669
        }
33✔
2670

2671
        switch v := val.(type) {
4,674✔
2672
        case bool:
96✔
2673
                {
192✔
2674
                        return &Bool{val: v}, nil
96✔
2675
                }
96✔
2676
        case string:
1,039✔
2677
                {
2,078✔
2678
                        return &Varchar{val: v}, nil
1,039✔
2679
                }
1,039✔
2680
        case int:
1,651✔
2681
                {
3,302✔
2682
                        return &Integer{val: int64(v)}, nil
1,651✔
2683
                }
1,651✔
2684
        case uint:
×
2685
                {
×
2686
                        return &Integer{val: int64(v)}, nil
×
2687
                }
×
2688
        case uint64:
34✔
2689
                {
68✔
2690
                        return &Integer{val: int64(v)}, nil
34✔
2691
                }
34✔
2692
        case int64:
125✔
2693
                {
250✔
2694
                        return &Integer{val: v}, nil
125✔
2695
                }
125✔
2696
        case []byte:
14✔
2697
                {
28✔
2698
                        return &Blob{val: v}, nil
14✔
2699
                }
14✔
2700
        case time.Time:
850✔
2701
                {
1,700✔
2702
                        return &Timestamp{val: v.Truncate(time.Microsecond).UTC()}, nil
850✔
2703
                }
850✔
2704
        case float64:
864✔
2705
                {
1,728✔
2706
                        return &Float64{val: v}, nil
864✔
2707
                }
864✔
2708
        }
2709
        return nil, ErrUnsupportedParameter
1✔
2710
}
2711

2712
func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
×
2713
        return nil, ErrUnexpected
×
2714
}
×
2715

2716
func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2717
        return p
×
2718
}
×
2719

2720
func (p *Param) isConstant() bool {
130✔
2721
        return true
130✔
2722
}
130✔
2723

2724
func (v *Param) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
5✔
2725
        return nil
5✔
2726
}
5✔
2727

2728
func (v *Param) String() string {
2✔
2729
        return "@" + v.id
2✔
2730
}
2✔
2731

2732
// Comparison identifies a binary comparison operator.
type Comparison int

// Comparison operators. Values are assigned sequentially from zero in the
// order listed.
const (
	EqualTo          Comparison = iota // equality
	LowerThan                          // strictly lower
	LowerOrEqualTo                     // lower or equal
	GreaterThan                        // strictly greater
	GreaterOrEqualTo                   // greater or equal
)
2741

2742
type DataSource interface {
2743
        SQLStmt
2744
        Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error)
2745
        Alias() string
2746
}
2747

2748
type SelectStmt struct {
2749
        distinct  bool
2750
        selectors []Selector
2751
        ds        DataSource
2752
        indexOn   []string
2753
        joins     []*JoinSpec
2754
        where     ValueExp
2755
        groupBy   []*ColSelector
2756
        having    ValueExp
2757
        orderBy   []*OrdCol
2758
        limit     ValueExp
2759
        offset    ValueExp
2760
        as        string
2761
}
2762

2763
func NewSelectStmt(
2764
        selectors []Selector,
2765
        ds DataSource,
2766
        where ValueExp,
2767
        orderBy []*OrdCol,
2768
        limit ValueExp,
2769
        offset ValueExp,
2770
) *SelectStmt {
71✔
2771
        return &SelectStmt{
71✔
2772
                selectors: selectors,
71✔
2773
                ds:        ds,
71✔
2774
                where:     where,
71✔
2775
                orderBy:   orderBy,
71✔
2776
                limit:     limit,
71✔
2777
                offset:    offset,
71✔
2778
        }
71✔
2779
}
71✔
2780

2781
func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
52✔
2782
        _, err := stmt.execAt(ctx, tx, nil)
52✔
2783
        if err != nil {
52✔
2784
                return err
×
2785
        }
×
2786

2787
        // TODO (jeroiraz) may be optimized so to resolve the query statement just once
2788
        rowReader, err := stmt.Resolve(ctx, tx, nil, nil)
52✔
2789
        if err != nil {
52✔
2790
                return err
×
2791
        }
×
2792
        defer rowReader.Close()
52✔
2793

52✔
2794
        return rowReader.InferParameters(ctx, params)
52✔
2795
}
2796

2797
func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
545✔
2798
        if stmt.groupBy == nil && stmt.having != nil {
546✔
2799
                return nil, ErrHavingClauseRequiresGroupClause
1✔
2800
        }
1✔
2801

2802
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
629✔
2803
                for _, sel := range stmt.selectors {
227✔
2804
                        _, isAgg := sel.(*AggColSelector)
142✔
2805
                        if !isAgg && !stmt.groupByContains(sel) {
144✔
2806
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
2807
                        }
2✔
2808
                }
2809
        }
2810

2811
        if len(stmt.orderBy) > 0 {
678✔
2812
                for _, col := range stmt.orderBy {
298✔
2813
                        sel := col.sel
162✔
2814
                        _, isAgg := sel.(*AggColSelector)
162✔
2815
                        if (isAgg && !stmt.containsSelector(sel)) || (!isAgg && len(stmt.groupBy) > 0 && !stmt.groupByContains(sel)) {
164✔
2816
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
2817
                        }
2✔
2818
                }
2819
        }
2820
        return tx, nil
540✔
2821
}
2822

2823
func (stmt *SelectStmt) containsSelector(s Selector) bool {
4✔
2824
        encSel := EncodeSelector(s.resolve(stmt.Alias()))
4✔
2825

4✔
2826
        for _, sel := range stmt.selectors {
12✔
2827
                if EncodeSelector(sel.resolve(stmt.Alias())) == encSel {
11✔
2828
                        return true
3✔
2829
                }
3✔
2830
        }
2831
        return false
1✔
2832
}
2833

2834
func (stmt *SelectStmt) groupByContains(sel Selector) bool {
57✔
2835
        encSel := EncodeSelector(sel.resolve(stmt.Alias()))
57✔
2836

57✔
2837
        for _, colSel := range stmt.groupBy {
137✔
2838
                if EncodeSelector(colSel.resolve(stmt.Alias())) == encSel {
134✔
2839
                        return true
54✔
2840
                }
54✔
2841
        }
2842
        return false
3✔
2843
}
2844

2845
func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
768✔
2846
        scanSpecs, err := stmt.genScanSpecs(tx, params)
768✔
2847
        if err != nil {
782✔
2848
                return nil, err
14✔
2849
        }
14✔
2850

2851
        rowReader, err := stmt.ds.Resolve(ctx, tx, params, scanSpecs)
754✔
2852
        if err != nil {
755✔
2853
                return nil, err
1✔
2854
        }
1✔
2855
        defer func() {
1,506✔
2856
                if err != nil {
759✔
2857
                        rowReader.Close()
6✔
2858
                }
6✔
2859
        }()
2860

2861
        if stmt.joins != nil {
765✔
2862
                var jointRowReader *jointRowReader
12✔
2863
                jointRowReader, err = newJointRowReader(rowReader, stmt.joins)
12✔
2864
                if err != nil {
13✔
2865
                        return nil, err
1✔
2866
                }
1✔
2867
                rowReader = jointRowReader
11✔
2868
        }
2869

2870
        if stmt.where != nil {
1,162✔
2871
                rowReader = newConditionalRowReader(rowReader, stmt.where)
410✔
2872
        }
410✔
2873

2874
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
832✔
2875
                if len(scanSpecs.groupBySortColumns) > 0 {
92✔
2876
                        var sortRowReader *sortRowReader
12✔
2877
                        sortRowReader, err = newSortRowReader(rowReader, scanSpecs.groupBySortColumns)
12✔
2878
                        if err != nil {
12✔
2879
                                return nil, err
×
2880
                        }
×
2881
                        rowReader = sortRowReader
12✔
2882
                }
2883

2884
                var groupedRowReader *groupedRowReader
80✔
2885
                groupedRowReader, err = newGroupedRowReader(rowReader, stmt.selectors, stmt.groupBy)
80✔
2886
                if err != nil {
82✔
2887
                        return nil, err
2✔
2888
                }
2✔
2889
                rowReader = groupedRowReader
78✔
2890

78✔
2891
                if stmt.having != nil {
82✔
2892
                        rowReader = newConditionalRowReader(rowReader, stmt.having)
4✔
2893
                }
4✔
2894
        }
2895

2896
        if len(scanSpecs.orderBySortCols) > 0 {
781✔
2897
                var sortRowReader *sortRowReader
31✔
2898
                sortRowReader, err = newSortRowReader(rowReader, stmt.orderBy)
31✔
2899
                if err != nil {
32✔
2900
                        return nil, err
1✔
2901
                }
1✔
2902
                rowReader = sortRowReader
30✔
2903
        }
2904

2905
        projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.selectors)
749✔
2906
        if err != nil {
750✔
2907
                return nil, err
1✔
2908
        }
1✔
2909
        rowReader = projectedRowReader
748✔
2910

748✔
2911
        if stmt.distinct {
756✔
2912
                var distinctRowReader *distinctRowReader
8✔
2913
                distinctRowReader, err = newDistinctRowReader(ctx, rowReader)
8✔
2914
                if err != nil {
9✔
2915
                        return nil, err
1✔
2916
                }
1✔
2917
                rowReader = distinctRowReader
7✔
2918
        }
2919

2920
        if stmt.offset != nil {
797✔
2921
                var offset int
50✔
2922
                offset, err = evalExpAsInt(tx, stmt.offset, params)
50✔
2923
                if err != nil {
50✔
2924
                        return nil, fmt.Errorf("%w: invalid offset", err)
×
2925
                }
×
2926

2927
                rowReader = newOffsetRowReader(rowReader, offset)
50✔
2928
        }
2929

2930
        if stmt.limit != nil {
843✔
2931
                var limit int
96✔
2932
                limit, err = evalExpAsInt(tx, stmt.limit, params)
96✔
2933
                if err != nil {
96✔
2934
                        return nil, fmt.Errorf("%w: invalid limit", err)
×
2935
                }
×
2936

2937
                if limit < 0 {
96✔
2938
                        return nil, fmt.Errorf("%w: invalid limit", ErrIllegalArguments)
×
2939
                }
×
2940

2941
                if limit > 0 {
138✔
2942
                        rowReader = newLimitRowReader(rowReader, limit)
42✔
2943
                }
42✔
2944
        }
2945

2946
        return rowReader, nil
747✔
2947
}
2948

2949
func (stmt *SelectStmt) rearrangeOrdColumns(groupByCols, orderByCols []*OrdCol) ([]*OrdCol, []*OrdCol) {
754✔
2950
        if len(groupByCols) > 0 && len(orderByCols) > 0 && !ordColumnsHaveAggregations(orderByCols) {
760✔
2951
                if ordColsHasPrefix(orderByCols, groupByCols, stmt.Alias()) {
8✔
2952
                        return orderByCols, nil
2✔
2953
                }
2✔
2954

2955
                if ordColsHasPrefix(groupByCols, orderByCols, stmt.Alias()) {
5✔
2956
                        for i := range orderByCols {
2✔
2957
                                groupByCols[i].descOrder = orderByCols[i].descOrder
1✔
2958
                        }
1✔
2959
                        return groupByCols, nil
1✔
2960
                }
2961
        }
2962
        return groupByCols, orderByCols
751✔
2963
}
2964

2965
func ordColsHasPrefix(cols, prefix []*OrdCol, table string) bool {
10✔
2966
        if len(prefix) > len(cols) {
12✔
2967
                return false
2✔
2968
        }
2✔
2969

2970
        for i := range prefix {
17✔
2971
                if EncodeSelector(prefix[i].sel.resolve(table)) != EncodeSelector(cols[i].sel.resolve(table)) {
14✔
2972
                        return false
5✔
2973
                }
5✔
2974
        }
2975
        return true
3✔
2976
}
2977

2978
func (stmt *SelectStmt) groupByOrdColumns() []*OrdCol {
768✔
2979
        groupByCols := stmt.groupBy
768✔
2980

768✔
2981
        ordCols := make([]*OrdCol, 0, len(groupByCols))
768✔
2982
        for _, col := range groupByCols {
815✔
2983
                ordCols = append(ordCols, &OrdCol{sel: col})
47✔
2984
        }
47✔
2985
        return ordCols
768✔
2986
}
2987

2988
func ordColumnsHaveAggregations(cols []*OrdCol) bool {
7✔
2989
        for _, ordCol := range cols {
17✔
2990
                if _, isAgg := ordCol.sel.(*AggColSelector); isAgg {
11✔
2991
                        return true
1✔
2992
                }
1✔
2993
        }
2994
        return false
6✔
2995
}
2996

2997
func (stmt *SelectStmt) containsAggregations() bool {
1,296✔
2998
        for _, sel := range stmt.selectors {
2,770✔
2999
                _, isAgg := sel.(*AggColSelector)
1,474✔
3000
                if isAgg {
1,637✔
3001
                        return true
163✔
3002
                }
163✔
3003
        }
3004
        return false
1,133✔
3005
}
3006

3007
func evalExpAsInt(tx *SQLTx, exp ValueExp, params map[string]interface{}) (int, error) {
146✔
3008
        offset, err := exp.substitute(params)
146✔
3009
        if err != nil {
146✔
3010
                return 0, err
×
3011
        }
×
3012

3013
        texp, err := offset.reduce(tx, nil, "")
146✔
3014
        if err != nil {
146✔
3015
                return 0, err
×
3016
        }
×
3017

3018
        convVal, err := mayApplyImplicitConversion(texp.RawValue(), IntegerType)
146✔
3019
        if err != nil {
146✔
3020
                return 0, ErrInvalidValue
×
3021
        }
×
3022

3023
        num, ok := convVal.(int64)
146✔
3024
        if !ok {
146✔
3025
                return 0, ErrInvalidValue
×
3026
        }
×
3027

3028
        if num > math.MaxInt32 {
146✔
3029
                return 0, ErrInvalidValue
×
3030
        }
×
3031

3032
        return int(num), nil
146✔
3033
}
3034

3035
func (stmt *SelectStmt) Alias() string {
167✔
3036
        if stmt.as == "" {
333✔
3037
                return stmt.ds.Alias()
166✔
3038
        }
166✔
3039

3040
        return stmt.as
1✔
3041
}
3042

3043
func (stmt *SelectStmt) hasTxMetadata() bool {
702✔
3044
        for _, sel := range stmt.selectors {
1,458✔
3045
                switch s := sel.(type) {
756✔
3046
                case *ColSelector:
650✔
3047
                        if s.col == txMetadataCol {
651✔
3048
                                return true
1✔
3049
                        }
1✔
3050
                case *JSONSelector:
13✔
3051
                        if s.ColSelector.col == txMetadataCol {
16✔
3052
                                return true
3✔
3053
                        }
3✔
3054
                }
3055
        }
3056
        return false
698✔
3057
}
3058

3059
func (stmt *SelectStmt) genScanSpecs(tx *SQLTx, params map[string]interface{}) (*ScanSpecs, error) {
768✔
3060
        groupByCols, orderByCols := stmt.groupByOrdColumns(), stmt.orderBy
768✔
3061

768✔
3062
        tableRef, isTableRef := stmt.ds.(*tableRef)
768✔
3063
        if !isTableRef {
820✔
3064
                groupByCols, orderByCols = stmt.rearrangeOrdColumns(groupByCols, orderByCols)
52✔
3065

52✔
3066
                return &ScanSpecs{
52✔
3067
                        groupBySortColumns: groupByCols,
52✔
3068
                        orderBySortCols:    orderByCols,
52✔
3069
                }, nil
52✔
3070
        }
52✔
3071

3072
        table, err := tableRef.referencedTable(tx)
716✔
3073
        if err != nil {
728✔
3074
                return nil, err
12✔
3075
        }
12✔
3076

3077
        rangesByColID := make(map[uint32]*typedValueRange)
704✔
3078
        if stmt.where != nil {
1,103✔
3079
                err = stmt.where.selectorRanges(table, tableRef.Alias(), params, rangesByColID)
399✔
3080
                if err != nil {
401✔
3081
                        return nil, err
2✔
3082
                }
2✔
3083
        }
3084

3085
        preferredIndex, err := stmt.getPreferredIndex(table)
702✔
3086
        if err != nil {
702✔
3087
                return nil, err
×
3088
        }
×
3089

3090
        var sortingIndex *Index
702✔
3091
        if preferredIndex == nil {
1,374✔
3092
                sortingIndex = stmt.selectSortingIndex(groupByCols, orderByCols, table, rangesByColID)
672✔
3093
        } else {
702✔
3094
                sortingIndex = preferredIndex
30✔
3095
        }
30✔
3096

3097
        if sortingIndex == nil {
1,284✔
3098
                sortingIndex = table.primaryIndex
582✔
3099
        }
582✔
3100

3101
        if tableRef.history && !sortingIndex.IsPrimary() {
702✔
3102
                return nil, fmt.Errorf("%w: historical queries are supported over primary index", ErrIllegalArguments)
×
3103
        }
×
3104

3105
        var descOrder bool
702✔
3106
        if len(groupByCols) > 0 && sortingIndex.coversOrdCols(groupByCols, rangesByColID) {
719✔
3107
                groupByCols = nil
17✔
3108
        }
17✔
3109

3110
        if len(groupByCols) == 0 && len(orderByCols) > 0 && sortingIndex.coversOrdCols(orderByCols, rangesByColID) {
801✔
3111
                descOrder = orderByCols[0].descOrder
99✔
3112
                orderByCols = nil
99✔
3113
        }
99✔
3114

3115
        groupByCols, orderByCols = stmt.rearrangeOrdColumns(groupByCols, orderByCols)
702✔
3116

702✔
3117
        return &ScanSpecs{
702✔
3118
                Index:              sortingIndex,
702✔
3119
                rangesByColID:      rangesByColID,
702✔
3120
                IncludeHistory:     tableRef.history,
702✔
3121
                IncludeTxMetadata:  stmt.hasTxMetadata(),
702✔
3122
                DescOrder:          descOrder,
702✔
3123
                groupBySortColumns: groupByCols,
702✔
3124
                orderBySortCols:    orderByCols,
702✔
3125
        }, nil
702✔
3126
}
3127

3128
func (stmt *SelectStmt) selectSortingIndex(groupByCols, orderByCols []*OrdCol, table *Table, rangesByColId map[uint32]*typedValueRange) *Index {
672✔
3129
        sortCols := groupByCols
672✔
3130
        if len(sortCols) == 0 {
1,318✔
3131
                sortCols = orderByCols
646✔
3132
        }
646✔
3133

3134
        if len(sortCols) == 0 {
1,226✔
3135
                return nil
554✔
3136
        }
554✔
3137

3138
        for _, idx := range table.indexes {
322✔
3139
                if idx.coversOrdCols(sortCols, rangesByColId) {
294✔
3140
                        return idx
90✔
3141
                }
90✔
3142
        }
3143
        return nil
28✔
3144
}
3145

3146
func (stmt *SelectStmt) getPreferredIndex(table *Table) (*Index, error) {
702✔
3147
        if len(stmt.indexOn) == 0 {
1,374✔
3148
                return nil, nil
672✔
3149
        }
672✔
3150

3151
        cols := make([]*Column, len(stmt.indexOn))
30✔
3152
        for i, colName := range stmt.indexOn {
80✔
3153
                col, err := table.GetColumnByName(colName)
50✔
3154
                if err != nil {
50✔
3155
                        return nil, err
×
3156
                }
×
3157

3158
                cols[i] = col
50✔
3159
        }
3160
        return table.GetIndexByName(indexName(table.name, cols))
30✔
3161
}
3162

3163
type UnionStmt struct {
3164
        distinct    bool
3165
        left, right DataSource
3166
}
3167

3168
func (stmt *UnionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3169
        err := stmt.left.inferParameters(ctx, tx, params)
1✔
3170
        if err != nil {
1✔
3171
                return err
×
3172
        }
×
3173
        return stmt.right.inferParameters(ctx, tx, params)
1✔
3174
}
3175

3176
func (stmt *UnionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
3177
        _, err := stmt.left.execAt(ctx, tx, params)
9✔
3178
        if err != nil {
9✔
3179
                return tx, err
×
3180
        }
×
3181

3182
        return stmt.right.execAt(ctx, tx, params)
9✔
3183
}
3184

3185
func (stmt *UnionStmt) resolveUnionAll(ctx context.Context, tx *SQLTx, params map[string]interface{}) (ret RowReader, err error) {
11✔
3186
        leftRowReader, err := stmt.left.Resolve(ctx, tx, params, nil)
11✔
3187
        if err != nil {
12✔
3188
                return nil, err
1✔
3189
        }
1✔
3190
        defer func() {
20✔
3191
                if err != nil {
14✔
3192
                        leftRowReader.Close()
4✔
3193
                }
4✔
3194
        }()
3195

3196
        rightRowReader, err := stmt.right.Resolve(ctx, tx, params, nil)
10✔
3197
        if err != nil {
11✔
3198
                return nil, err
1✔
3199
        }
1✔
3200
        defer func() {
18✔
3201
                if err != nil {
12✔
3202
                        rightRowReader.Close()
3✔
3203
                }
3✔
3204
        }()
3205

3206
        rowReader, err := newUnionRowReader(ctx, []RowReader{leftRowReader, rightRowReader})
9✔
3207
        if err != nil {
12✔
3208
                return nil, err
3✔
3209
        }
3✔
3210

3211
        return rowReader, nil
6✔
3212
}
3213

3214
func (stmt *UnionStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
11✔
3215
        rowReader, err := stmt.resolveUnionAll(ctx, tx, params)
11✔
3216
        if err != nil {
16✔
3217
                return nil, err
5✔
3218
        }
5✔
3219
        defer func() {
12✔
3220
                if err != nil {
7✔
3221
                        rowReader.Close()
1✔
3222
                }
1✔
3223
        }()
3224

3225
        if stmt.distinct {
11✔
3226
                distinctReader, err := newDistinctRowReader(ctx, rowReader)
5✔
3227
                if err != nil {
6✔
3228
                        return nil, err
1✔
3229
                }
1✔
3230
                rowReader = distinctReader
4✔
3231
        }
3232

3233
        return rowReader, nil
5✔
3234
}
3235

3236
func (stmt *UnionStmt) Alias() string {
×
3237
        return ""
×
3238
}
×
3239

3240
func NewTableRef(table string, as string) *tableRef {
179✔
3241
        return &tableRef{
179✔
3242
                table: table,
179✔
3243
                as:    as,
179✔
3244
        }
179✔
3245
}
179✔
3246

3247
type tableRef struct {
3248
        table   string
3249
        history bool
3250
        period  period
3251
        as      string
3252
}
3253

3254
type period struct {
3255
        start *openPeriod
3256
        end   *openPeriod
3257
}
3258

3259
type openPeriod struct {
3260
        inclusive bool
3261
        instant   periodInstant
3262
}
3263

3264
type periodInstant struct {
3265
        exp         ValueExp
3266
        instantType instantType
3267
}
3268

3269
// instantType selects how a periodInstant expression is interpreted.
type instantType = int

const (
	// txInstant: the expression evaluates to a transaction ID.
	txInstant instantType = iota
	// timeInstant: the expression evaluates to a point in time.
	timeInstant
)
3275

3276
func (i periodInstant) resolve(tx *SQLTx, params map[string]interface{}, asc, inclusive bool) (uint64, error) {
81✔
3277
        exp, err := i.exp.substitute(params)
81✔
3278
        if err != nil {
81✔
3279
                return 0, err
×
3280
        }
×
3281

3282
        instantVal, err := exp.reduce(tx, nil, "")
81✔
3283
        if err != nil {
83✔
3284
                return 0, err
2✔
3285
        }
2✔
3286

3287
        if i.instantType == txInstant {
124✔
3288
                txID, ok := instantVal.RawValue().(int64)
45✔
3289
                if !ok {
45✔
3290
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %s given", ErrIllegalArguments, instantVal.Type())
×
3291
                }
×
3292

3293
                if txID <= 0 {
52✔
3294
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %d given", ErrIllegalArguments, txID)
7✔
3295
                }
7✔
3296

3297
                if inclusive {
61✔
3298
                        return uint64(txID), nil
23✔
3299
                }
23✔
3300

3301
                if asc {
26✔
3302
                        return uint64(txID + 1), nil
11✔
3303
                }
11✔
3304

3305
                if txID <= 1 {
5✔
3306
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be greater than 1, %d given", ErrIllegalArguments, txID)
1✔
3307
                }
1✔
3308

3309
                return uint64(txID - 1), nil
3✔
3310
        } else {
34✔
3311

34✔
3312
                var ts time.Time
34✔
3313

34✔
3314
                if instantVal.Type() == TimestampType {
67✔
3315
                        ts = instantVal.RawValue().(time.Time)
33✔
3316
                } else {
34✔
3317
                        conv, err := getConverter(instantVal.Type(), TimestampType)
1✔
3318
                        if err != nil {
1✔
3319
                                return 0, err
×
3320
                        }
×
3321

3322
                        tval, err := conv(instantVal)
1✔
3323
                        if err != nil {
1✔
3324
                                return 0, err
×
3325
                        }
×
3326

3327
                        ts = tval.RawValue().(time.Time)
1✔
3328
                }
3329

3330
                sts := ts
34✔
3331

34✔
3332
                if asc {
57✔
3333
                        if !inclusive {
34✔
3334
                                sts = sts.Add(1 * time.Second)
11✔
3335
                        }
11✔
3336

3337
                        txHdr, err := tx.engine.store.FirstTxSince(sts)
23✔
3338
                        if err != nil {
34✔
3339
                                return 0, err
11✔
3340
                        }
11✔
3341

3342
                        return txHdr.ID, nil
12✔
3343
                }
3344

3345
                if !inclusive {
11✔
3346
                        sts = sts.Add(-1 * time.Second)
×
3347
                }
×
3348

3349
                txHdr, err := tx.engine.store.LastTxUntil(sts)
11✔
3350
                if err != nil {
11✔
3351
                        return 0, err
×
3352
                }
×
3353

3354
                return txHdr.ID, nil
11✔
3355
        }
3356
}
3357

3358
func (stmt *tableRef) referencedTable(tx *SQLTx) (*Table, error) {
3,356✔
3359
        table, err := tx.catalog.GetTableByName(stmt.table)
3,356✔
3360
        if err != nil {
3,372✔
3361
                return nil, err
16✔
3362
        }
16✔
3363

3364
        return table, nil
3,340✔
3365
}
3366

3367
func (stmt *tableRef) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3368
        return nil
1✔
3369
}
1✔
3370

3371
func (stmt *tableRef) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3372
        return tx, nil
×
3373
}
×
3374

3375
func (stmt *tableRef) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
718✔
3376
        if tx == nil {
718✔
3377
                return nil, ErrIllegalArguments
×
3378
        }
×
3379

3380
        table, err := stmt.referencedTable(tx)
718✔
3381
        if err != nil {
718✔
3382
                return nil, err
×
3383
        }
×
3384

3385
        return newRawRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
718✔
3386
}
3387

3388
func (stmt *tableRef) Alias() string {
563✔
3389
        if stmt.as == "" {
1,080✔
3390
                return stmt.table
517✔
3391
        }
517✔
3392
        return stmt.as
46✔
3393
}
3394

3395
type JoinSpec struct {
3396
        joinType JoinType
3397
        ds       DataSource
3398
        cond     ValueExp
3399
        indexOn  []string
3400
}
3401

3402
type OrdCol struct {
3403
        sel       Selector
3404
        descOrder bool
3405
}
3406

3407
func NewOrdCol(table string, col string, descOrder bool) *OrdCol {
1✔
3408
        return &OrdCol{
1✔
3409
                sel:       NewColSelector(table, col),
1✔
3410
                descOrder: descOrder,
1✔
3411
        }
1✔
3412
}
1✔
3413

3414
type Selector interface {
3415
        ValueExp
3416
        resolve(implicitTable string) (aggFn, table, col string)
3417
        alias() string
3418
        setAlias(alias string)
3419
}
3420

3421
// ColSelector references a plain (non-aggregated) column.
type ColSelector struct {
	// table is the optional table qualifier; empty means "use the implicit table".
	table string
	// col is the column name.
	col string
	// as is the optional output alias.
	as string
}
3426

3427
func NewColSelector(table, col string) *ColSelector {
125✔
3428
        return &ColSelector{
125✔
3429
                table: table,
125✔
3430
                col:   col,
125✔
3431
        }
125✔
3432
}
125✔
3433

3434
func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) {
386,167✔
3435
        table = implicitTable
386,167✔
3436
        if sel.table != "" {
496,978✔
3437
                table = sel.table
110,811✔
3438
        }
110,811✔
3439
        return "", table, sel.col
386,167✔
3440
}
3441

3442
func (sel *ColSelector) alias() string {
125,672✔
3443
        if sel.as == "" {
251,252✔
3444
                return sel.col
125,580✔
3445
        }
125,580✔
3446

3447
        return sel.as
92✔
3448
}
3449

3450
func (sel *ColSelector) setAlias(alias string) {
698✔
3451
        sel.as = alias
698✔
3452
}
698✔
3453

3454
func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
525✔
3455
        _, table, col := sel.resolve(implicitTable)
525✔
3456
        encSel := EncodeSelector("", table, col)
525✔
3457

525✔
3458
        desc, ok := cols[encSel]
525✔
3459
        if !ok {
528✔
3460
                return AnyType, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
3✔
3461
        }
3✔
3462

3463
        return desc.Type, nil
522✔
3464
}
3465

3466
func (sel *ColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
3467
        _, table, col := sel.resolve(implicitTable)
15✔
3468
        encSel := EncodeSelector("", table, col)
15✔
3469

15✔
3470
        desc, ok := cols[encSel]
15✔
3471
        if !ok {
17✔
3472
                return fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
2✔
3473
        }
2✔
3474

3475
        if desc.Type != t {
16✔
3476
                return fmt.Errorf("%w: %v(%s) can not be interpreted as type %v", ErrInvalidTypes, desc.Type, encSel, t)
3✔
3477
        }
3✔
3478

3479
        return nil
10✔
3480
}
3481

3482
func (sel *ColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
3,011✔
3483
        return sel, nil
3,011✔
3484
}
3,011✔
3485

3486
func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
319,207✔
3487
        if row == nil {
319,208✔
3488
                return nil, fmt.Errorf("%w: no row to evaluate in current context", ErrInvalidValue)
1✔
3489
        }
1✔
3490

3491
        aggFn, table, col := sel.resolve(implicitTable)
319,206✔
3492

319,206✔
3493
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
319,206✔
3494
        if !ok {
319,213✔
3495
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
7✔
3496
        }
7✔
3497

3498
        return v, nil
319,199✔
3499
}
3500

3501
func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
339✔
3502
        aggFn, table, col := sel.resolve(implicitTable)
339✔
3503

339✔
3504
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
339✔
3505
        if !ok {
503✔
3506
                return sel
164✔
3507
        }
164✔
3508

3509
        return v
175✔
3510
}
3511

3512
func (sel *ColSelector) isConstant() bool {
12✔
3513
        return false
12✔
3514
}
12✔
3515

3516
func (sel *ColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
11✔
3517
        return nil
11✔
3518
}
11✔
3519

3520
func (sel *ColSelector) String() string {
27✔
3521
        return sel.col
27✔
3522
}
27✔
3523

3524
type AggColSelector struct {
3525
        aggFn AggregateFn
3526
        table string
3527
        col   string
3528
        as    string
3529
}
3530

3531
func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector {
16✔
3532
        return &AggColSelector{
16✔
3533
                aggFn: aggFn,
16✔
3534
                table: table,
16✔
3535
                col:   col,
16✔
3536
        }
16✔
3537
}
16✔
3538

3539
// EncodeSelector builds the canonical key "aggFn(table.col)" used to index
// row values, e.g. "COUNT(t.c)". For plain columns aggFn is empty, giving
// "(t.c)". Plain concatenation keeps this to one allocation, which matters
// because the function sits on the per-row evaluation path.
func EncodeSelector(aggFn, table, col string) string {
	return aggFn + "(" + table + "." + col + ")"
}
602,936✔
3542

3543
func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,581✔
3544
        table = implicitTable
1,581✔
3545
        if sel.table != "" {
1,712✔
3546
                table = sel.table
131✔
3547
        }
131✔
3548

3549
        return sel.aggFn, table, sel.col
1,581✔
3550
}
3551

3552
func (sel *AggColSelector) alias() string {
435✔
3553
        return sel.as
435✔
3554
}
435✔
3555

3556
func (sel *AggColSelector) setAlias(alias string) {
106✔
3557
        sel.as = alias
106✔
3558
}
106✔
3559

3560
func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
33✔
3561
        if sel.aggFn == COUNT {
50✔
3562
                return IntegerType, nil
17✔
3563
        }
17✔
3564

3565
        colSelector := &ColSelector{table: sel.table, col: sel.col}
16✔
3566

16✔
3567
        if sel.aggFn == SUM || sel.aggFn == AVG {
23✔
3568
                t, err := colSelector.inferType(cols, params, implicitTable)
7✔
3569
                if err != nil {
7✔
3570
                        return AnyType, err
×
3571
                }
×
3572

3573
                if t != IntegerType && t != Float64Type {
7✔
3574
                        return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
×
3575

×
3576
                }
×
3577

3578
                return t, nil
7✔
3579
        }
3580

3581
        return colSelector.inferType(cols, params, implicitTable)
9✔
3582
}
3583

3584
func (sel *AggColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
8✔
3585
        if sel.aggFn == COUNT {
10✔
3586
                if t != IntegerType {
3✔
3587
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
3588
                }
1✔
3589
                return nil
1✔
3590
        }
3591

3592
        colSelector := &ColSelector{table: sel.table, col: sel.col}
6✔
3593

6✔
3594
        if sel.aggFn == SUM || sel.aggFn == AVG {
10✔
3595
                if t != IntegerType && t != Float64Type {
5✔
3596
                        return fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
1✔
3597
                }
1✔
3598
        }
3599

3600
        return colSelector.requiresType(t, cols, params, implicitTable)
5✔
3601
}
3602

3603
func (sel *AggColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
25✔
3604
        return sel, nil
25✔
3605
}
25✔
3606

3607
func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
459✔
3608
        if row == nil {
460✔
3609
                return nil, fmt.Errorf("%w: no row to evaluate aggregation (%s) in current context", ErrInvalidValue, sel.aggFn)
1✔
3610
        }
1✔
3611

3612
        v, ok := row.ValuesBySelector[EncodeSelector(sel.resolve(implicitTable))]
458✔
3613
        if !ok {
459✔
3614
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, sel.col)
1✔
3615
        }
1✔
3616
        return v, nil
457✔
3617
}
3618

3619
func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
3620
        return sel
×
3621
}
×
3622

3623
func (sel *AggColSelector) isConstant() bool {
1✔
3624
        return false
1✔
3625
}
1✔
3626

3627
func (sel *AggColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
3628
        return nil
×
3629
}
×
3630

NEW
3631
func (sel *AggColSelector) String() string {
×
NEW
3632
        return sel.aggFn + "(" + sel.col + ")"
×
NEW
3633
}
×
3634

3635
type NumExp struct {
3636
        op          NumOperator
3637
        left, right ValueExp
3638
}
3639

3640
func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
6✔
3641
        // First step - check if we can infer the type of sub-expressions
6✔
3642
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
6✔
3643
        if err != nil {
6✔
3644
                return AnyType, err
×
3645
        }
×
3646
        if tleft != AnyType && tleft != IntegerType && tleft != Float64Type {
6✔
3647
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft)
×
3648
        }
×
3649

3650
        tright, err := bexp.right.inferType(cols, params, implicitTable)
6✔
3651
        if err != nil {
6✔
3652
                return AnyType, err
×
3653
        }
×
3654
        if tright != AnyType && tright != IntegerType && tright != Float64Type {
8✔
3655
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright)
2✔
3656
        }
2✔
3657

3658
        if tleft == IntegerType && tright == IntegerType {
6✔
3659
                // Both sides are integer types - the result is also integer
2✔
3660
                return IntegerType, nil
2✔
3661
        }
2✔
3662

3663
        if tleft != AnyType && tright != AnyType {
2✔
3664
                // Both sides have concrete types but at least one of them is float
×
3665
                return Float64Type, nil
×
3666
        }
×
3667

3668
        // Both sides are ambiguous
3669
        return AnyType, nil
2✔
3670
}
3671

3672
func copyParams(params map[string]SQLValueType) map[string]SQLValueType {
11✔
3673
        ret := make(map[string]SQLValueType, len(params))
11✔
3674
        for k, v := range params {
15✔
3675
                ret[k] = v
4✔
3676
        }
4✔
3677
        return ret
11✔
3678
}
3679

3680
func restoreParams(params, restore map[string]SQLValueType) {
2✔
3681
        for k := range params {
2✔
3682
                delete(params, k)
×
3683
        }
×
3684
        for k, v := range restore {
2✔
3685
                params[k] = v
×
3686
        }
×
3687
}
3688

3689
func (bexp *NumExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
3690
        if t != IntegerType && t != Float64Type {
8✔
3691
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
3692
        }
1✔
3693

3694
        floatArgs := 2
6✔
3695
        paramsOrig := copyParams(params)
6✔
3696
        err := bexp.left.requiresType(t, cols, params, implicitTable)
6✔
3697
        if err != nil && t == Float64Type {
7✔
3698
                restoreParams(params, paramsOrig)
1✔
3699
                floatArgs--
1✔
3700
                err = bexp.left.requiresType(IntegerType, cols, params, implicitTable)
1✔
3701
        }
1✔
3702
        if err != nil {
7✔
3703
                return err
1✔
3704
        }
1✔
3705

3706
        paramsOrig = copyParams(params)
5✔
3707
        err = bexp.right.requiresType(t, cols, params, implicitTable)
5✔
3708
        if err != nil && t == Float64Type {
6✔
3709
                restoreParams(params, paramsOrig)
1✔
3710
                floatArgs--
1✔
3711
                err = bexp.right.requiresType(IntegerType, cols, params, implicitTable)
1✔
3712
        }
1✔
3713
        if err != nil {
7✔
3714
                return err
2✔
3715
        }
2✔
3716

3717
        if t == Float64Type && floatArgs == 0 {
3✔
3718
                // Currently this case requires explicit float cast
×
3719
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
×
3720
        }
×
3721

3722
        return nil
3✔
3723
}
3724

3725
func (bexp *NumExp) substitute(params map[string]interface{}) (ValueExp, error) {
154✔
3726
        rlexp, err := bexp.left.substitute(params)
154✔
3727
        if err != nil {
154✔
3728
                return nil, err
×
3729
        }
×
3730

3731
        rrexp, err := bexp.right.substitute(params)
154✔
3732
        if err != nil {
154✔
3733
                return nil, err
×
3734
        }
×
3735

3736
        bexp.left = rlexp
154✔
3737
        bexp.right = rrexp
154✔
3738

154✔
3739
        return bexp, nil
154✔
3740
}
3741

3742
func (bexp *NumExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
153✔
3743
        vl, err := bexp.left.reduce(tx, row, implicitTable)
153✔
3744
        if err != nil {
153✔
3745
                return nil, err
×
3746
        }
×
3747

3748
        vr, err := bexp.right.reduce(tx, row, implicitTable)
153✔
3749
        if err != nil {
153✔
3750
                return nil, err
×
3751
        }
×
3752

3753
        vl = unwrapJSON(vl)
153✔
3754
        vr = unwrapJSON(vr)
153✔
3755

153✔
3756
        return applyNumOperator(bexp.op, vl, vr)
153✔
3757
}
3758

3759
func unwrapJSON(v TypedValue) TypedValue {
306✔
3760
        if jsonVal, ok := v.(*JSON); ok {
406✔
3761
                if sv, isSimple := jsonVal.castToTypedValue(); isSimple {
200✔
3762
                        return sv
100✔
3763
                }
100✔
3764
        }
3765
        return v
206✔
3766
}
3767

3768
func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3769
        return &NumExp{
1✔
3770
                op:    bexp.op,
1✔
3771
                left:  bexp.left.reduceSelectors(row, implicitTable),
1✔
3772
                right: bexp.right.reduceSelectors(row, implicitTable),
1✔
3773
        }
1✔
3774
}
1✔
3775

3776
func (bexp *NumExp) isConstant() bool {
5✔
3777
        return bexp.left.isConstant() && bexp.right.isConstant()
5✔
3778
}
5✔
3779

3780
func (bexp *NumExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
3✔
3781
        return nil
3✔
3782
}
3✔
3783

3784
func (bexp *NumExp) String() string {
9✔
3785
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), NumOperatorString(bexp.op), bexp.right.String())
9✔
3786
}
9✔
3787

3788
type NotBoolExp struct {
3789
        exp ValueExp
3790
}
3791

3792
func (bexp *NotBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
2✔
3793
        err := bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
2✔
3794
        if err != nil {
2✔
3795
                return AnyType, err
×
3796
        }
×
3797

3798
        return BooleanType, nil
2✔
3799
}
3800

3801
func (bexp *NotBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
3802
        if t != BooleanType {
7✔
3803
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
3804
        }
1✔
3805

3806
        return bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
5✔
3807
}
3808

3809
func (bexp *NotBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
22✔
3810
        rexp, err := bexp.exp.substitute(params)
22✔
3811
        if err != nil {
22✔
3812
                return nil, err
×
3813
        }
×
3814

3815
        bexp.exp = rexp
22✔
3816

22✔
3817
        return bexp, nil
22✔
3818
}
3819

3820
func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
22✔
3821
        v, err := bexp.exp.reduce(tx, row, implicitTable)
22✔
3822
        if err != nil {
22✔
3823
                return nil, err
×
3824
        }
×
3825

3826
        r, isBool := v.RawValue().(bool)
22✔
3827
        if !isBool {
22✔
3828
                return nil, ErrInvalidCondition
×
3829
        }
×
3830

3831
        return &Bool{val: !r}, nil
22✔
3832
}
3833

3834
func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
3835
        return &NotBoolExp{
×
3836
                exp: bexp.exp.reduceSelectors(row, implicitTable),
×
3837
        }
×
3838
}
×
3839

3840
func (bexp *NotBoolExp) isConstant() bool {
1✔
3841
        return bexp.exp.isConstant()
1✔
3842
}
1✔
3843

3844
func (bexp *NotBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
3845
        return nil
7✔
3846
}
7✔
3847

3848
func (bexp *NotBoolExp) String() string {
1✔
3849
        return "NOT " + bexp.exp.String()
1✔
3850
}
1✔
3851

3852
type LikeBoolExp struct {
3853
        val     ValueExp
3854
        notLike bool
3855
        pattern ValueExp
3856
}
3857

3858
func NewLikeBoolExp(val ValueExp, notLike bool, pattern ValueExp) *LikeBoolExp {
3✔
3859
        return &LikeBoolExp{
3✔
3860
                val:     val,
3✔
3861
                notLike: notLike,
3✔
3862
                pattern: pattern,
3✔
3863
        }
3✔
3864
}
3✔
3865

3866
func (bexp *LikeBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
3867
        if bexp.val == nil || bexp.pattern == nil {
4✔
3868
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
3869
        }
1✔
3870

3871
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
2✔
3872
        if err != nil {
3✔
3873
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
3874
        }
1✔
3875

3876
        return BooleanType, nil
1✔
3877
}
3878

3879
func (bexp *LikeBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
3880
        if bexp.val == nil || bexp.pattern == nil {
9✔
3881
                return fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
2✔
3882
        }
2✔
3883

3884
        if t != BooleanType {
7✔
3885
                return fmt.Errorf("error using the value of the LIKE operator as %s: %w", t, ErrInvalidTypes)
2✔
3886
        }
2✔
3887

3888
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
3✔
3889
        if err != nil {
4✔
3890
                return fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
3891
        }
1✔
3892

3893
        return nil
2✔
3894
}
3895

3896
func (bexp *LikeBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
135✔
3897
        if bexp.val == nil || bexp.pattern == nil {
136✔
3898
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
3899
        }
1✔
3900

3901
        val, err := bexp.val.substitute(params)
134✔
3902
        if err != nil {
134✔
3903
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
3904
        }
×
3905

3906
        pattern, err := bexp.pattern.substitute(params)
134✔
3907
        if err != nil {
134✔
3908
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
3909
        }
×
3910

3911
        return &LikeBoolExp{
134✔
3912
                val:     val,
134✔
3913
                notLike: bexp.notLike,
134✔
3914
                pattern: pattern,
134✔
3915
        }, nil
134✔
3916
}
3917

3918
func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
142✔
3919
        if bexp.val == nil || bexp.pattern == nil {
143✔
3920
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
3921
        }
1✔
3922

3923
        rval, err := bexp.val.reduce(tx, row, implicitTable)
141✔
3924
        if err != nil {
141✔
3925
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
3926
        }
×
3927

3928
        if rval.IsNull() {
142✔
3929
                return &Bool{val: false}, nil
1✔
3930
        }
1✔
3931

3932
        rvalStr, ok := rval.RawValue().(string)
140✔
3933
        if !ok {
141✔
3934
                return nil, fmt.Errorf("error in 'LIKE' clause: %w (expecting %s)", ErrInvalidTypes, VarcharType)
1✔
3935
        }
1✔
3936

3937
        rpattern, err := bexp.pattern.reduce(tx, row, implicitTable)
139✔
3938
        if err != nil {
139✔
3939
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
3940
        }
×
3941

3942
        if rpattern.Type() != VarcharType {
139✔
3943
                return nil, fmt.Errorf("error evaluating 'LIKE' clause: %w", ErrInvalidTypes)
×
3944
        }
×
3945

3946
        matched, err := regexp.MatchString(rpattern.RawValue().(string), rvalStr)
139✔
3947
        if err != nil {
139✔
3948
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
3949
        }
×
3950

3951
        return &Bool{val: matched != bexp.notLike}, nil
139✔
3952
}
3953

3954
func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3955
        return bexp
1✔
3956
}
1✔
3957

3958
func (bexp *LikeBoolExp) isConstant() bool {
2✔
3959
        return false
2✔
3960
}
2✔
3961

3962
func (bexp *LikeBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
3963
        return nil
8✔
3964
}
8✔
3965

3966
func (bexp *LikeBoolExp) String() string {
3✔
3967
        return fmt.Sprintf("(%s LIKE %s)", bexp.val.String(), bexp.pattern.String())
3✔
3968
}
3✔
3969

3970
type CmpBoolExp struct {
3971
        op          CmpOperator
3972
        left, right ValueExp
3973
}
3974

3975
func NewCmpBoolExp(op CmpOperator, left, right ValueExp) *CmpBoolExp {
66✔
3976
        return &CmpBoolExp{
66✔
3977
                op:    op,
66✔
3978
                left:  left,
66✔
3979
                right: right,
66✔
3980
        }
66✔
3981
}
66✔
3982

3983
func (bexp *CmpBoolExp) Left() ValueExp {
×
3984
        return bexp.left
×
3985
}
×
3986

3987
func (bexp *CmpBoolExp) Right() ValueExp {
×
3988
        return bexp.right
×
3989
}
×
3990

3991
func (bexp *CmpBoolExp) OP() CmpOperator {
×
3992
        return bexp.op
×
3993
}
×
3994

3995
func (bexp *CmpBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
59✔
3996
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
59✔
3997
        if err != nil {
59✔
3998
                return AnyType, err
×
3999
        }
×
4000

4001
        tright, err := bexp.right.inferType(cols, params, implicitTable)
59✔
4002
        if err != nil {
62✔
4003
                return AnyType, err
3✔
4004
        }
3✔
4005

4006
        // unification step
4007

4008
        if tleft == tright {
65✔
4009
                return BooleanType, nil
9✔
4010
        }
9✔
4011

4012
        if tleft != AnyType && tright != AnyType {
51✔
4013
                return AnyType, fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, tleft, tright)
4✔
4014
        }
4✔
4015

4016
        if tleft == AnyType {
47✔
4017
                err = bexp.left.requiresType(tright, cols, params, implicitTable)
4✔
4018
                if err != nil {
4✔
4019
                        return AnyType, err
×
4020
                }
×
4021
        }
4022

4023
        if tright == AnyType {
82✔
4024
                err = bexp.right.requiresType(tleft, cols, params, implicitTable)
39✔
4025
                if err != nil {
39✔
4026
                        return AnyType, err
×
4027
                }
×
4028
        }
4029

4030
        return BooleanType, nil
43✔
4031
}
4032

4033
func (bexp *CmpBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
41✔
4034
        if t != BooleanType {
42✔
4035
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4036
        }
1✔
4037

4038
        _, err := bexp.inferType(cols, params, implicitTable)
40✔
4039

40✔
4040
        return err
40✔
4041
}
4042

4043
func (bexp *CmpBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
3,253✔
4044
        rlexp, err := bexp.left.substitute(params)
3,253✔
4045
        if err != nil {
3,253✔
4046
                return nil, err
×
4047
        }
×
4048

4049
        rrexp, err := bexp.right.substitute(params)
3,253✔
4050
        if err != nil {
3,254✔
4051
                return nil, err
1✔
4052
        }
1✔
4053

4054
        bexp.left = rlexp
3,252✔
4055
        bexp.right = rrexp
3,252✔
4056

3,252✔
4057
        return bexp, nil
3,252✔
4058
}
4059

4060
func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
3,131✔
4061
        vl, err := bexp.left.reduce(tx, row, implicitTable)
3,131✔
4062
        if err != nil {
3,134✔
4063
                return nil, err
3✔
4064
        }
3✔
4065

4066
        vr, err := bexp.right.reduce(tx, row, implicitTable)
3,128✔
4067
        if err != nil {
3,130✔
4068
                return nil, err
2✔
4069
        }
2✔
4070

4071
        r, err := vl.Compare(vr)
3,126✔
4072
        if err != nil {
3,130✔
4073
                return nil, err
4✔
4074
        }
4✔
4075

4076
        return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil
3,122✔
4077
}
4078

4079
func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
167✔
4080
        return &CmpBoolExp{
167✔
4081
                op:    bexp.op,
167✔
4082
                left:  bexp.left.reduceSelectors(row, implicitTable),
167✔
4083
                right: bexp.right.reduceSelectors(row, implicitTable),
167✔
4084
        }
167✔
4085
}
167✔
4086

4087
func (bexp *CmpBoolExp) isConstant() bool {
2✔
4088
        return bexp.left.isConstant() && bexp.right.isConstant()
2✔
4089
}
2✔
4090

4091
func (bexp *CmpBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
489✔
4092
        matchingFunc := func(left, right ValueExp) (*ColSelector, ValueExp, bool) {
1,074✔
4093
                s, isSel := bexp.left.(*ColSelector)
585✔
4094
                if isSel && s.col != revCol && bexp.right.isConstant() {
978✔
4095
                        return s, right, true
393✔
4096
                }
393✔
4097
                return nil, nil, false
192✔
4098
        }
4099

4100
        sel, c, ok := matchingFunc(bexp.left, bexp.right)
489✔
4101
        if !ok {
585✔
4102
                sel, c, ok = matchingFunc(bexp.right, bexp.left)
96✔
4103
        }
96✔
4104

4105
        if !ok {
585✔
4106
                return nil
96✔
4107
        }
96✔
4108

4109
        aggFn, t, col := sel.resolve(table.name)
393✔
4110
        if aggFn != "" || t != asTable {
407✔
4111
                return nil
14✔
4112
        }
14✔
4113

4114
        column, err := table.GetColumnByName(col)
379✔
4115
        if err != nil {
380✔
4116
                return err
1✔
4117
        }
1✔
4118

4119
        val, err := c.substitute(params)
378✔
4120
        if errors.Is(err, ErrMissingParameter) {
437✔
4121
                // TODO: not supported when parameters are not provided during query resolution
59✔
4122
                return nil
59✔
4123
        }
59✔
4124
        if err != nil {
319✔
4125
                return err
×
4126
        }
×
4127

4128
        rval, err := val.reduce(nil, nil, table.name)
319✔
4129
        if err != nil {
320✔
4130
                return err
1✔
4131
        }
1✔
4132

4133
        return updateRangeFor(column.id, rval, bexp.op, rangesByColID)
318✔
4134
}
4135

4136
func (bexp *CmpBoolExp) String() string {
19✔
4137
        opStr := CmpOperatorToString(bexp.op)
19✔
4138
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), opStr, bexp.right.String())
19✔
4139
}
19✔
4140

4141
func updateRangeFor(colID uint32, val TypedValue, cmp CmpOperator, rangesByColID map[uint32]*typedValueRange) error {
318✔
4142
        currRange, ranged := rangesByColID[colID]
318✔
4143
        var newRange *typedValueRange
318✔
4144

318✔
4145
        switch cmp {
318✔
4146
        case EQ:
250✔
4147
                {
500✔
4148
                        newRange = &typedValueRange{
250✔
4149
                                lRange: &typedValueSemiRange{
250✔
4150
                                        val:       val,
250✔
4151
                                        inclusive: true,
250✔
4152
                                },
250✔
4153
                                hRange: &typedValueSemiRange{
250✔
4154
                                        val:       val,
250✔
4155
                                        inclusive: true,
250✔
4156
                                },
250✔
4157
                        }
250✔
4158
                }
250✔
4159
        case LT:
13✔
4160
                {
26✔
4161
                        newRange = &typedValueRange{
13✔
4162
                                hRange: &typedValueSemiRange{
13✔
4163
                                        val: val,
13✔
4164
                                },
13✔
4165
                        }
13✔
4166
                }
13✔
4167
        case LE:
10✔
4168
                {
20✔
4169
                        newRange = &typedValueRange{
10✔
4170
                                hRange: &typedValueSemiRange{
10✔
4171
                                        val:       val,
10✔
4172
                                        inclusive: true,
10✔
4173
                                },
10✔
4174
                        }
10✔
4175
                }
10✔
4176
        case GT:
18✔
4177
                {
36✔
4178
                        newRange = &typedValueRange{
18✔
4179
                                lRange: &typedValueSemiRange{
18✔
4180
                                        val: val,
18✔
4181
                                },
18✔
4182
                        }
18✔
4183
                }
18✔
4184
        case GE:
16✔
4185
                {
32✔
4186
                        newRange = &typedValueRange{
16✔
4187
                                lRange: &typedValueSemiRange{
16✔
4188
                                        val:       val,
16✔
4189
                                        inclusive: true,
16✔
4190
                                },
16✔
4191
                        }
16✔
4192
                }
16✔
4193
        case NE:
11✔
4194
                {
22✔
4195
                        return nil
11✔
4196
                }
11✔
4197
        }
4198

4199
        if !ranged {
611✔
4200
                rangesByColID[colID] = newRange
304✔
4201
                return nil
304✔
4202
        }
304✔
4203

4204
        return currRange.refineWith(newRange)
3✔
4205
}
4206

4207
func cmpSatisfiesOp(cmp int, op CmpOperator) bool {
3,122✔
4208
        switch {
3,122✔
4209
        case cmp == 0:
551✔
4210
                {
1,102✔
4211
                        return op == EQ || op == LE || op == GE
551✔
4212
                }
551✔
4213
        case cmp < 0:
1,409✔
4214
                {
2,818✔
4215
                        return op == NE || op == LT || op == LE
1,409✔
4216
                }
1,409✔
4217
        case cmp > 0:
1,162✔
4218
                {
2,324✔
4219
                        return op == NE || op == GT || op == GE
1,162✔
4220
                }
1,162✔
4221
        }
4222
        return false
×
4223
}
4224

4225
type BinBoolExp struct {
4226
        op          LogicOperator
4227
        left, right ValueExp
4228
}
4229

4230
func NewBinBoolExp(op LogicOperator, lrexp, rrexp ValueExp) *BinBoolExp {
18✔
4231
        bexp := &BinBoolExp{
18✔
4232
                op: op,
18✔
4233
        }
18✔
4234

18✔
4235
        bexp.left = lrexp
18✔
4236
        bexp.right = rrexp
18✔
4237

18✔
4238
        return bexp
18✔
4239
}
18✔
4240

4241
func (bexp *BinBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
19✔
4242
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
4243
        if err != nil {
19✔
4244
                return AnyType, err
×
4245
        }
×
4246

4247
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
19✔
4248
        if err != nil {
21✔
4249
                return AnyType, err
2✔
4250
        }
2✔
4251

4252
        return BooleanType, nil
17✔
4253
}
4254

4255
func (bexp *BinBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
22✔
4256
        if t != BooleanType {
25✔
4257
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
3✔
4258
        }
3✔
4259

4260
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
4261
        if err != nil {
20✔
4262
                return err
1✔
4263
        }
1✔
4264

4265
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
18✔
4266
        if err != nil {
18✔
4267
                return err
×
4268
        }
×
4269

4270
        return nil
18✔
4271
}
4272

4273
func (bexp *BinBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
470✔
4274
        rlexp, err := bexp.left.substitute(params)
470✔
4275
        if err != nil {
470✔
4276
                return nil, err
×
4277
        }
×
4278

4279
        rrexp, err := bexp.right.substitute(params)
470✔
4280
        if err != nil {
470✔
4281
                return nil, err
×
4282
        }
×
4283

4284
        bexp.left = rlexp
470✔
4285
        bexp.right = rrexp
470✔
4286

470✔
4287
        return bexp, nil
470✔
4288
}
4289

4290
func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
473✔
4291
        vl, err := bexp.left.reduce(tx, row, implicitTable)
473✔
4292
        if err != nil {
474✔
4293
                return nil, err
1✔
4294
        }
1✔
4295

4296
        bl, isBool := vl.(*Bool)
472✔
4297
        if !isBool {
472✔
4298
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
4299
        }
×
4300

4301
        // short-circuit evaluation
4302
        if (bl.val && bexp.op == OR) || (!bl.val && bexp.op == AND) {
648✔
4303
                return &Bool{val: bl.val}, nil
176✔
4304
        }
176✔
4305

4306
        vr, err := bexp.right.reduce(tx, row, implicitTable)
296✔
4307
        if err != nil {
297✔
4308
                return nil, err
1✔
4309
        }
1✔
4310

4311
        br, isBool := vr.(*Bool)
295✔
4312
        if !isBool {
295✔
4313
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
4314
        }
×
4315

4316
        switch bexp.op {
295✔
4317
        case AND:
273✔
4318
                {
546✔
4319
                        return &Bool{val: bl.val && br.val}, nil
273✔
4320
                }
273✔
4321
        case OR:
22✔
4322
                {
44✔
4323
                        return &Bool{val: bl.val || br.val}, nil
22✔
4324
                }
22✔
4325
        }
4326

4327
        return nil, ErrUnexpected
×
4328
}
4329

4330
func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
15✔
4331
        return &BinBoolExp{
15✔
4332
                op:    bexp.op,
15✔
4333
                left:  bexp.left.reduceSelectors(row, implicitTable),
15✔
4334
                right: bexp.right.reduceSelectors(row, implicitTable),
15✔
4335
        }
15✔
4336
}
15✔
4337

4338
func (bexp *BinBoolExp) isConstant() bool {
1✔
4339
        return bexp.left.isConstant() && bexp.right.isConstant()
1✔
4340
}
1✔
4341

4342
func (bexp *BinBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
150✔
4343
        if bexp.op == AND {
287✔
4344
                err := bexp.left.selectorRanges(table, asTable, params, rangesByColID)
137✔
4345
                if err != nil {
137✔
4346
                        return err
×
4347
                }
×
4348

4349
                return bexp.right.selectorRanges(table, asTable, params, rangesByColID)
137✔
4350
        }
4351

4352
        lRanges := make(map[uint32]*typedValueRange)
13✔
4353
        rRanges := make(map[uint32]*typedValueRange)
13✔
4354

13✔
4355
        err := bexp.left.selectorRanges(table, asTable, params, lRanges)
13✔
4356
        if err != nil {
13✔
4357
                return err
×
4358
        }
×
4359

4360
        err = bexp.right.selectorRanges(table, asTable, params, rRanges)
13✔
4361
        if err != nil {
13✔
4362
                return err
×
4363
        }
×
4364

4365
        for colID, lr := range lRanges {
20✔
4366
                rr, ok := rRanges[colID]
7✔
4367
                if !ok {
9✔
4368
                        continue
2✔
4369
                }
4370

4371
                err = lr.extendWith(rr)
5✔
4372
                if err != nil {
5✔
4373
                        return err
×
4374
                }
×
4375

4376
                rangesByColID[colID] = lr
5✔
4377
        }
4378

4379
        return nil
13✔
4380
}
4381

4382
func (bexp *BinBoolExp) String() string {
7✔
4383
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), LogicOperatorToString(bexp.op), bexp.right.String())
7✔
4384
}
7✔
4385

4386
type ExistsBoolExp struct {
4387
        q DataSource
4388
}
4389

4390
func (bexp *ExistsBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4391
        return AnyType, fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
4392
}
1✔
4393

4394
func (bexp *ExistsBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
4395
        return fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
4396
}
1✔
4397

4398
func (bexp *ExistsBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
4399
        return bexp, nil
1✔
4400
}
1✔
4401

4402
func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
4403
        return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported)
2✔
4404
}
2✔
4405

4406
func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4407
        return bexp
1✔
4408
}
1✔
4409

4410
func (bexp *ExistsBoolExp) isConstant() bool {
2✔
4411
        return false
2✔
4412
}
2✔
4413

4414
func (bexp *ExistsBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
4415
        return nil
1✔
4416
}
1✔
4417

NEW
4418
func (bexp *ExistsBoolExp) String() string {
×
NEW
4419
        return ""
×
NEW
4420
}
×
4421

4422
type InSubQueryExp struct {
4423
        val   ValueExp
4424
        notIn bool
4425
        q     *SelectStmt
4426
}
4427

4428
func (bexp *InSubQueryExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4429
        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
4430
}
1✔
4431

4432
func (bexp *InSubQueryExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
4433
        return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
4434
}
1✔
4435

4436
func (bexp *InSubQueryExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
4437
        return bexp, nil
1✔
4438
}
1✔
4439

4440
func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
4441
        return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
2✔
4442
}
2✔
4443

4444
func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4445
        return bexp
1✔
4446
}
1✔
4447

4448
func (bexp *InSubQueryExp) isConstant() bool {
1✔
4449
        return false
1✔
4450
}
1✔
4451

4452
func (bexp *InSubQueryExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
4453
        return nil
1✔
4454
}
1✔
4455

NEW
4456
func (bexp *InSubQueryExp) String() string {
×
NEW
4457
        return ""
×
NEW
4458
}
×
4459

4460
// TODO: once InSubQueryExp is supported, this struct may become obsolete by creating a ListDataSource struct
4461
type InListExp struct {
4462
        val    ValueExp
4463
        notIn  bool
4464
        values []ValueExp
4465
}
4466

4467
func (bexp *InListExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
6✔
4468
        t, err := bexp.val.inferType(cols, params, implicitTable)
6✔
4469
        if err != nil {
7✔
4470
                return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
4471
        }
1✔
4472

4473
        for _, v := range bexp.values {
15✔
4474
                err = v.requiresType(t, cols, params, implicitTable)
10✔
4475
                if err != nil {
11✔
4476
                        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
4477
                }
1✔
4478
        }
4479

4480
        return BooleanType, nil
4✔
4481
}
4482

4483
func (bexp *InListExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
4484
        _, err := bexp.inferType(cols, params, implicitTable)
2✔
4485
        if err != nil {
3✔
4486
                return err
1✔
4487
        }
1✔
4488

4489
        if t != BooleanType {
1✔
4490
                return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrInvalidTypes)
×
4491
        }
×
4492

4493
        return nil
1✔
4494
}
4495

4496
func (bexp *InListExp) substitute(params map[string]interface{}) (ValueExp, error) {
115✔
4497
        val, err := bexp.val.substitute(params)
115✔
4498
        if err != nil {
115✔
4499
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
4500
        }
×
4501

4502
        values := make([]ValueExp, len(bexp.values))
115✔
4503

115✔
4504
        for i, val := range bexp.values {
245✔
4505
                values[i], err = val.substitute(params)
130✔
4506
                if err != nil {
130✔
4507
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
4508
                }
×
4509
        }
4510

4511
        return &InListExp{
115✔
4512
                val:    val,
115✔
4513
                notIn:  bexp.notIn,
115✔
4514
                values: values,
115✔
4515
        }, nil
115✔
4516
}
4517

4518
func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
115✔
4519
        rval, err := bexp.val.reduce(tx, row, implicitTable)
115✔
4520
        if err != nil {
116✔
4521
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
4522
        }
1✔
4523

4524
        var found bool
114✔
4525

114✔
4526
        for _, v := range bexp.values {
241✔
4527
                rv, err := v.reduce(tx, row, implicitTable)
127✔
4528
                if err != nil {
128✔
4529
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
4530
                }
1✔
4531

4532
                r, err := rval.Compare(rv)
126✔
4533
                if err != nil {
127✔
4534
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
4535
                }
1✔
4536

4537
                if r == 0 {
140✔
4538
                        // TODO: short-circuit evaluation may be preferred when upfront static type inference is in place
15✔
4539
                        found = found || true
15✔
4540
                }
15✔
4541
        }
4542

4543
        return &Bool{val: found != bexp.notIn}, nil
112✔
4544
}
4545

4546
func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
4547
        values := make([]ValueExp, len(bexp.values))
10✔
4548

10✔
4549
        for i, val := range bexp.values {
20✔
4550
                values[i] = val.reduceSelectors(row, implicitTable)
10✔
4551
        }
10✔
4552

4553
        return &InListExp{
10✔
4554
                val:    bexp.val.reduceSelectors(row, implicitTable),
10✔
4555
                values: values,
10✔
4556
        }
10✔
4557
}
4558

4559
func (bexp *InListExp) isConstant() bool {
1✔
4560
        return false
1✔
4561
}
1✔
4562

4563
func (bexp *InListExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
20✔
4564
        // TODO: may be determiined by smallest and bigggest value in the list
20✔
4565
        return nil
20✔
4566
}
20✔
4567

4568
func (bexp *InListExp) String() string {
1✔
4569
        values := make([]string, len(bexp.values))
1✔
4570
        for i, exp := range bexp.values {
5✔
4571
                values[i] = exp.String()
4✔
4572
        }
4✔
4573
        return fmt.Sprintf("%s IN (%s)", bexp.val.String(), strings.Join(values, ","))
1✔
4574
}
4575

4576
type FnDataSourceStmt struct {
4577
        fnCall *FnCall
4578
        as     string
4579
}
4580

4581
func (stmt *FnDataSourceStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
4582
        return tx, nil
×
4583
}
×
4584

4585
func (stmt *FnDataSourceStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
4586
        return nil
1✔
4587
}
1✔
4588

4589
func (stmt *FnDataSourceStmt) Alias() string {
22✔
4590
        if stmt.as != "" {
24✔
4591
                return stmt.as
2✔
4592
        }
2✔
4593

4594
        switch strings.ToUpper(stmt.fnCall.fn) {
20✔
4595
        case DatabasesFnCall:
3✔
4596
                {
6✔
4597
                        return "databases"
3✔
4598
                }
3✔
4599
        case TablesFnCall:
5✔
4600
                {
10✔
4601
                        return "tables"
5✔
4602
                }
5✔
4603
        case TableFnCall:
×
4604
                {
×
4605
                        return "table"
×
4606
                }
×
4607
        case UsersFnCall:
7✔
4608
                {
14✔
4609
                        return "users"
7✔
4610
                }
7✔
4611
        case ColumnsFnCall:
3✔
4612
                {
6✔
4613
                        return "columns"
3✔
4614
                }
3✔
4615
        case IndexesFnCall:
2✔
4616
                {
4✔
4617
                        return "indexes"
2✔
4618
                }
2✔
4619
        }
4620

4621
        // not reachable
4622
        return ""
×
4623
}
4624

4625
func (stmt *FnDataSourceStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (rowReader RowReader, err error) {
23✔
4626
        if stmt.fnCall == nil {
23✔
4627
                return nil, fmt.Errorf("%w: function is unspecified", ErrIllegalArguments)
×
4628
        }
×
4629

4630
        switch strings.ToUpper(stmt.fnCall.fn) {
23✔
4631
        case DatabasesFnCall:
5✔
4632
                {
10✔
4633
                        return stmt.resolveListDatabases(ctx, tx, params, scanSpecs)
5✔
4634
                }
5✔
4635
        case TablesFnCall:
5✔
4636
                {
10✔
4637
                        return stmt.resolveListTables(ctx, tx, params, scanSpecs)
5✔
4638
                }
5✔
4639
        case TableFnCall:
×
4640
                {
×
4641
                        return stmt.resolveShowTable(ctx, tx, params, scanSpecs)
×
4642
                }
×
4643
        case UsersFnCall:
7✔
4644
                {
14✔
4645
                        return stmt.resolveListUsers(ctx, tx, params, scanSpecs)
7✔
4646
                }
7✔
4647
        case ColumnsFnCall:
3✔
4648
                {
6✔
4649
                        return stmt.resolveListColumns(ctx, tx, params, scanSpecs)
3✔
4650
                }
3✔
4651
        case IndexesFnCall:
3✔
4652
                {
6✔
4653
                        return stmt.resolveListIndexes(ctx, tx, params, scanSpecs)
3✔
4654
                }
3✔
4655
        }
4656

4657
        return nil, fmt.Errorf("%w (%s)", ErrFunctionDoesNotExist, stmt.fnCall.fn)
×
4658
}
4659

4660
func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
4661
        if len(stmt.fnCall.params) > 0 {
5✔
4662
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, DatabasesFnCall, len(stmt.fnCall.params))
×
4663
        }
×
4664

4665
        cols := make([]ColDescriptor, 1)
5✔
4666
        cols[0] = ColDescriptor{
5✔
4667
                Column: "name",
5✔
4668
                Type:   VarcharType,
5✔
4669
        }
5✔
4670

5✔
4671
        var dbs []string
5✔
4672

5✔
4673
        if tx.engine.multidbHandler == nil {
6✔
4674
                return nil, ErrUnspecifiedMultiDBHandler
1✔
4675
        } else {
5✔
4676
                dbs, err = tx.engine.multidbHandler.ListDatabases(ctx)
4✔
4677
                if err != nil {
4✔
4678
                        return nil, err
×
4679
                }
×
4680
        }
4681

4682
        values := make([][]ValueExp, len(dbs))
4✔
4683

4✔
4684
        for i, db := range dbs {
12✔
4685
                values[i] = []ValueExp{&Varchar{val: db}}
8✔
4686
        }
8✔
4687

4688
        return newValuesRowReader(tx, params, cols, stmt.Alias(), values)
4✔
4689
}
4690

4691
func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
4692
        if len(stmt.fnCall.params) > 0 {
5✔
4693
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, TablesFnCall, len(stmt.fnCall.params))
×
4694
        }
×
4695

4696
        cols := make([]ColDescriptor, 1)
5✔
4697
        cols[0] = ColDescriptor{
5✔
4698
                Column: "name",
5✔
4699
                Type:   VarcharType,
5✔
4700
        }
5✔
4701

5✔
4702
        tables := tx.catalog.GetTables()
5✔
4703

5✔
4704
        values := make([][]ValueExp, len(tables))
5✔
4705

5✔
4706
        for i, t := range tables {
14✔
4707
                values[i] = []ValueExp{&Varchar{val: t.name}}
9✔
4708
        }
9✔
4709

4710
        return newValuesRowReader(tx, params, cols, stmt.Alias(), values)
5✔
4711
}
4712

4713
func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
×
4714
        cols := []ColDescriptor{
×
4715
                {
×
4716
                        Column: "column_name",
×
4717
                        Type:   VarcharType,
×
4718
                },
×
4719
                {
×
4720
                        Column: "type_name",
×
4721
                        Type:   VarcharType,
×
4722
                },
×
4723
                {
×
4724
                        Column: "is_nullable",
×
4725
                        Type:   BooleanType,
×
4726
                },
×
4727
                {
×
4728
                        Column: "is_indexed",
×
4729
                        Type:   VarcharType,
×
4730
                },
×
4731
                {
×
4732
                        Column: "is_auto_increment",
×
4733
                        Type:   BooleanType,
×
4734
                },
×
4735
                {
×
4736
                        Column: "is_unique",
×
4737
                        Type:   BooleanType,
×
4738
                },
×
4739
        }
×
4740

×
4741
        tableName, _ := stmt.fnCall.params[0].reduce(tx, nil, "")
×
4742
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
×
4743
        if err != nil {
×
4744
                return nil, err
×
4745
        }
×
4746

4747
        values := make([][]ValueExp, len(table.cols))
×
4748

×
4749
        for i, c := range table.cols {
×
4750
                index := "NO"
×
4751

×
4752
                indexed, err := table.IsIndexed(c.Name())
×
4753
                if err != nil {
×
4754
                        return nil, err
×
4755
                }
×
4756
                if indexed {
×
4757
                        index = "YES"
×
4758
                }
×
4759

4760
                if table.PrimaryIndex().IncludesCol(c.ID()) {
×
4761
                        index = "PRIMARY KEY"
×
4762
                }
×
4763

4764
                var unique bool
×
4765
                for _, index := range table.GetIndexesByColID(c.ID()) {
×
4766
                        if index.IsUnique() && len(index.Cols()) == 1 {
×
4767
                                unique = true
×
4768
                                break
×
4769
                        }
4770
                }
4771

4772
                var maxLen string
×
4773

×
4774
                if c.MaxLen() > 0 && (c.Type() == VarcharType || c.Type() == BLOBType) {
×
4775
                        maxLen = fmt.Sprintf("(%d)", c.MaxLen())
×
4776
                }
×
4777

4778
                values[i] = []ValueExp{
×
4779
                        &Varchar{val: c.colName},
×
4780
                        &Varchar{val: c.Type() + maxLen},
×
4781
                        &Bool{val: c.IsNullable()},
×
4782
                        &Varchar{val: index},
×
4783
                        &Bool{val: c.IsAutoIncremental()},
×
4784
                        &Bool{val: unique},
×
4785
                }
×
4786
        }
4787

4788
        return newValuesRowReader(tx, params, cols, stmt.Alias(), values)
×
4789
}
4790

4791
func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
7✔
4792
        if len(stmt.fnCall.params) > 0 {
7✔
4793
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, UsersFnCall, len(stmt.fnCall.params))
×
4794
        }
×
4795

4796
        cols := make([]ColDescriptor, 2)
7✔
4797
        cols[0] = ColDescriptor{
7✔
4798
                Column: "name",
7✔
4799
                Type:   VarcharType,
7✔
4800
        }
7✔
4801
        cols[1] = ColDescriptor{
7✔
4802
                Column: "permission",
7✔
4803
                Type:   VarcharType,
7✔
4804
        }
7✔
4805

7✔
4806
        var users []User
7✔
4807

7✔
4808
        if tx.engine.multidbHandler == nil {
7✔
4809
                return nil, ErrUnspecifiedMultiDBHandler
×
4810
        } else {
7✔
4811
                users, err = tx.engine.multidbHandler.ListUsers(ctx)
7✔
4812
                if err != nil {
7✔
4813
                        return nil, err
×
4814
                }
×
4815
        }
4816

4817
        values := make([][]ValueExp, len(users))
7✔
4818

7✔
4819
        for i, user := range users {
21✔
4820
                var perm string
14✔
4821

14✔
4822
                switch user.Permission() {
14✔
4823
                case 1:
4✔
4824
                        {
8✔
4825
                                perm = "READ"
4✔
4826
                        }
4✔
4827
                case 2:
2✔
4828
                        {
4✔
4829
                                perm = "READ/WRITE"
2✔
4830
                        }
2✔
4831
                case 254:
3✔
4832
                        {
6✔
4833
                                perm = "ADMIN"
3✔
4834
                        }
3✔
4835
                default:
5✔
4836
                        {
10✔
4837
                                perm = "SYSADMIN"
5✔
4838
                        }
5✔
4839
                }
4840

4841
                values[i] = []ValueExp{
14✔
4842
                        &Varchar{val: user.Username()},
14✔
4843
                        &Varchar{val: perm},
14✔
4844
                }
14✔
4845
        }
4846

4847
        return newValuesRowReader(tx, params, cols, stmt.Alias(), values)
7✔
4848
}
4849

4850
func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
4851
        if len(stmt.fnCall.params) != 1 {
3✔
4852
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, ColumnsFnCall)
×
4853
        }
×
4854

4855
        cols := []ColDescriptor{
3✔
4856
                {
3✔
4857
                        Column: "table",
3✔
4858
                        Type:   VarcharType,
3✔
4859
                },
3✔
4860
                {
3✔
4861
                        Column: "name",
3✔
4862
                        Type:   VarcharType,
3✔
4863
                },
3✔
4864
                {
3✔
4865
                        Column: "type",
3✔
4866
                        Type:   VarcharType,
3✔
4867
                },
3✔
4868
                {
3✔
4869
                        Column: "max_length",
3✔
4870
                        Type:   IntegerType,
3✔
4871
                },
3✔
4872
                {
3✔
4873
                        Column: "nullable",
3✔
4874
                        Type:   BooleanType,
3✔
4875
                },
3✔
4876
                {
3✔
4877
                        Column: "auto_increment",
3✔
4878
                        Type:   BooleanType,
3✔
4879
                },
3✔
4880
                {
3✔
4881
                        Column: "indexed",
3✔
4882
                        Type:   BooleanType,
3✔
4883
                },
3✔
4884
                {
3✔
4885
                        Column: "primary",
3✔
4886
                        Type:   BooleanType,
3✔
4887
                },
3✔
4888
                {
3✔
4889
                        Column: "unique",
3✔
4890
                        Type:   BooleanType,
3✔
4891
                },
3✔
4892
        }
3✔
4893

3✔
4894
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
4895
        if err != nil {
3✔
4896
                return nil, err
×
4897
        }
×
4898

4899
        tableName, err := val.reduce(tx, nil, "")
3✔
4900
        if err != nil {
3✔
4901
                return nil, err
×
4902
        }
×
4903

4904
        if tableName.Type() != VarcharType {
3✔
4905
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
4906
        }
×
4907

4908
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
4909
        if err != nil {
3✔
4910
                return nil, err
×
4911
        }
×
4912

4913
        values := make([][]ValueExp, len(table.cols))
3✔
4914

3✔
4915
        for i, c := range table.cols {
11✔
4916
                indexed, err := table.IsIndexed(c.Name())
8✔
4917
                if err != nil {
8✔
4918
                        return nil, err
×
4919
                }
×
4920

4921
                var unique bool
8✔
4922
                for _, index := range table.indexesByColID[c.id] {
16✔
4923
                        if index.IsUnique() && len(index.Cols()) == 1 {
11✔
4924
                                unique = true
3✔
4925
                                break
3✔
4926
                        }
4927
                }
4928

4929
                values[i] = []ValueExp{
8✔
4930
                        &Varchar{val: table.name},
8✔
4931
                        &Varchar{val: c.colName},
8✔
4932
                        &Varchar{val: c.colType},
8✔
4933
                        &Integer{val: int64(c.MaxLen())},
8✔
4934
                        &Bool{val: c.IsNullable()},
8✔
4935
                        &Bool{val: c.autoIncrement},
8✔
4936
                        &Bool{val: indexed},
8✔
4937
                        &Bool{val: table.PrimaryIndex().IncludesCol(c.ID())},
8✔
4938
                        &Bool{val: unique},
8✔
4939
                }
8✔
4940
        }
4941

4942
        return newValuesRowReader(tx, params, cols, stmt.Alias(), values)
3✔
4943
}
4944

4945
func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
4946
        if len(stmt.fnCall.params) != 1 {
3✔
4947
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, IndexesFnCall)
×
4948
        }
×
4949

4950
        cols := []ColDescriptor{
3✔
4951
                {
3✔
4952
                        Column: "table",
3✔
4953
                        Type:   VarcharType,
3✔
4954
                },
3✔
4955
                {
3✔
4956
                        Column: "name",
3✔
4957
                        Type:   VarcharType,
3✔
4958
                },
3✔
4959
                {
3✔
4960
                        Column: "unique",
3✔
4961
                        Type:   BooleanType,
3✔
4962
                },
3✔
4963
                {
3✔
4964
                        Column: "primary",
3✔
4965
                        Type:   BooleanType,
3✔
4966
                },
3✔
4967
        }
3✔
4968

3✔
4969
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
4970
        if err != nil {
3✔
4971
                return nil, err
×
4972
        }
×
4973

4974
        tableName, err := val.reduce(tx, nil, "")
3✔
4975
        if err != nil {
3✔
4976
                return nil, err
×
4977
        }
×
4978

4979
        if tableName.Type() != VarcharType {
3✔
4980
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
4981
        }
×
4982

4983
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
4984
        if err != nil {
3✔
4985
                return nil, err
×
4986
        }
×
4987

4988
        values := make([][]ValueExp, len(table.indexes))
3✔
4989

3✔
4990
        for i, index := range table.indexes {
10✔
4991
                values[i] = []ValueExp{
7✔
4992
                        &Varchar{val: table.name},
7✔
4993
                        &Varchar{val: index.Name()},
7✔
4994
                        &Bool{val: index.unique},
7✔
4995
                        &Bool{val: index.IsPrimary()},
7✔
4996
                }
7✔
4997
        }
7✔
4998

4999
        return newValuesRowReader(tx, params, cols, stmt.Alias(), values)
3✔
5000
}
5001

5002
// DropTableStmt represents a statement to delete a table.
5003
type DropTableStmt struct {
5004
        table string
5005
}
5006

5007
func NewDropTableStmt(table string) *DropTableStmt {
6✔
5008
        return &DropTableStmt{table: table}
6✔
5009
}
6✔
5010

5011
func (stmt *DropTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5012
        return nil
1✔
5013
}
1✔
5014

5015
/*
5016
Exec executes the delete table statement.
5017
It the table exists, if not it does nothing.
5018
If the table exists, it deletes all the indexes and the table itself.
5019
Note that this is a soft delete of the index and table key,
5020
the data is not deleted, but the metadata is updated.
5021
*/
5022
func (stmt *DropTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
7✔
5023
        if !tx.catalog.ExistTable(stmt.table) {
8✔
5024
                return nil, ErrTableDoesNotExist
1✔
5025
        }
1✔
5026

5027
        table, err := tx.catalog.GetTableByName(stmt.table)
6✔
5028
        if err != nil {
6✔
5029
                return nil, err
×
5030
        }
×
5031

5032
        // delete table
5033
        mappedKey := MapKey(
6✔
5034
                tx.sqlPrefix(),
6✔
5035
                catalogTablePrefix,
6✔
5036
                EncodeID(DatabaseID),
6✔
5037
                EncodeID(table.id),
6✔
5038
        )
6✔
5039
        err = tx.delete(ctx, mappedKey)
6✔
5040
        if err != nil {
6✔
5041
                return nil, err
×
5042
        }
×
5043

5044
        // delete columns
5045
        cols := table.ColumnsByID()
6✔
5046
        for _, col := range cols {
26✔
5047
                mappedKey := MapKey(
20✔
5048
                        tx.sqlPrefix(),
20✔
5049
                        catalogColumnPrefix,
20✔
5050
                        EncodeID(DatabaseID),
20✔
5051
                        EncodeID(col.table.id),
20✔
5052
                        EncodeID(col.id),
20✔
5053
                        []byte(col.colType),
20✔
5054
                )
20✔
5055
                err = tx.delete(ctx, mappedKey)
20✔
5056
                if err != nil {
20✔
5057
                        return nil, err
×
5058
                }
×
5059
        }
5060

5061
        // delete checks
5062
        for name := range table.checkConstraints {
6✔
NEW
5063
                key := MapKey(
×
NEW
5064
                        tx.sqlPrefix(),
×
NEW
5065
                        catalogCheckPrefix,
×
NEW
5066
                        EncodeID(DatabaseID),
×
NEW
5067
                        EncodeID(table.id),
×
NEW
5068
                        []byte(name),
×
NEW
5069
                )
×
NEW
5070

×
NEW
5071
                if err := tx.delete(ctx, key); err != nil {
×
NEW
5072
                        return nil, err
×
NEW
5073
                }
×
5074
        }
5075

5076
        // delete indexes
5077
        for _, index := range table.indexes {
13✔
5078
                mappedKey := MapKey(
7✔
5079
                        tx.sqlPrefix(),
7✔
5080
                        catalogIndexPrefix,
7✔
5081
                        EncodeID(DatabaseID),
7✔
5082
                        EncodeID(table.id),
7✔
5083
                        EncodeID(index.id),
7✔
5084
                )
7✔
5085
                err = tx.delete(ctx, mappedKey)
7✔
5086
                if err != nil {
7✔
5087
                        return nil, err
×
5088
                }
×
5089

5090
                indexKey := MapKey(
7✔
5091
                        tx.sqlPrefix(),
7✔
5092
                        MappedPrefix,
7✔
5093
                        EncodeID(table.id),
7✔
5094
                        EncodeID(index.id),
7✔
5095
                )
7✔
5096
                err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
14✔
5097
                        return sqlTx.engine.store.DeleteIndex(indexKey)
7✔
5098
                })
7✔
5099
                if err != nil {
7✔
5100
                        return nil, err
×
5101
                }
×
5102
        }
5103

5104
        err = tx.catalog.deleteTable(table)
6✔
5105
        if err != nil {
6✔
5106
                return nil, err
×
5107
        }
×
5108

5109
        tx.mutatedCatalog = true
6✔
5110

6✔
5111
        return tx, nil
6✔
5112
}
5113

5114
// DropIndexStmt represents a statement to delete an index.
// The index is identified by the table it belongs to and by the list of
// columns it was defined on.
type DropIndexStmt struct {
	table string   // name of the table owning the index
	cols  []string // indexed column names, in the order used to derive the index name
}
5119

5120
func NewDropIndexStmt(table string, cols []string) *DropIndexStmt {
4✔
5121
        return &DropIndexStmt{table: table, cols: cols}
4✔
5122
}
4✔
5123

5124
func (stmt *DropIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5125
        return nil
1✔
5126
}
1✔
5127

5128
/*
5129
Exec executes the delete index statement.
5130
If the index exists, it deletes it. Note that this is a soft delete of the index
5131
the data is not deleted, but the metadata is updated.
5132
*/
5133
func (stmt *DropIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
5134
        if !tx.catalog.ExistTable(stmt.table) {
7✔
5135
                return nil, ErrTableDoesNotExist
1✔
5136
        }
1✔
5137

5138
        table, err := tx.catalog.GetTableByName(stmt.table)
5✔
5139
        if err != nil {
5✔
5140
                return nil, err
×
5141
        }
×
5142

5143
        cols := make([]*Column, len(stmt.cols))
5✔
5144

5✔
5145
        for i, colName := range stmt.cols {
10✔
5146
                col, err := table.GetColumnByName(colName)
5✔
5147
                if err != nil {
5✔
5148
                        return nil, err
×
5149
                }
×
5150

5151
                cols[i] = col
5✔
5152
        }
5153

5154
        index, err := table.GetIndexByName(indexName(table.name, cols))
5✔
5155
        if err != nil {
5✔
5156
                return nil, err
×
5157
        }
×
5158

5159
        // delete index
5160
        mappedKey := MapKey(
5✔
5161
                tx.sqlPrefix(),
5✔
5162
                catalogIndexPrefix,
5✔
5163
                EncodeID(DatabaseID),
5✔
5164
                EncodeID(table.id),
5✔
5165
                EncodeID(index.id),
5✔
5166
        )
5✔
5167
        err = tx.delete(ctx, mappedKey)
5✔
5168
        if err != nil {
5✔
5169
                return nil, err
×
5170
        }
×
5171

5172
        indexKey := MapKey(
5✔
5173
                tx.sqlPrefix(),
5✔
5174
                MappedPrefix,
5✔
5175
                EncodeID(table.id),
5✔
5176
                EncodeID(index.id),
5✔
5177
        )
5✔
5178

5✔
5179
        err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
9✔
5180
                return sqlTx.engine.store.DeleteIndex(indexKey)
4✔
5181
        })
4✔
5182
        if err != nil {
5✔
5183
                return nil, err
×
5184
        }
×
5185

5186
        err = table.deleteIndex(index)
5✔
5187
        if err != nil {
6✔
5188
                return nil, err
1✔
5189
        }
1✔
5190

5191
        tx.mutatedCatalog = true
4✔
5192

4✔
5193
        return tx, nil
4✔
5194
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc