• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 18658094170

20 Oct 2025 04:13PM UTC coverage: 89.267% (+0.002%) from 89.265%
18658094170

Pull #2076

gh-ci

els-tmiller
fix spacing
Pull Request #2076: S3 Storage - Fargate Credentials

16 of 17 new or added lines in 4 files covered. (94.12%)

452 existing lines in 4 files now uncovered.

37974 of 42540 relevant lines covered (89.27%)

149951.18 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.2
/embedded/sql/stmt.go
1
/*
2
Copyright 2025 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "context"
22
        "encoding/binary"
23
        "encoding/hex"
24
        "errors"
25
        "fmt"
26
        "math"
27
        "regexp"
28
        "strconv"
29
        "strings"
30
        "time"
31

32
        "github.com/codenotary/immudb/embedded/store"
33
        "github.com/google/uuid"
34
)
35

36
// Key prefixes used to persist catalog metadata and table data in the
// underlying key-value store. The trailing comments sketch the key/value
// layout written under each prefix.
const (
	catalogPrefix          = "CTL."
	catalogTablePrefix     = "CTL.TABLE."     // (key=CTL.TABLE.{1}{tableID}, value={tableNAME})
	catalogColumnPrefix    = "CTL.COLUMN."    // (key=CTL.COLUMN.{1}{tableID}{colID}{colTYPE}, value={flags(autoIncrementFlag | nullableFlag=NOT NULL){maxLen}{colNAME}})
	catalogIndexPrefix     = "CTL.INDEX."     // (key=CTL.INDEX.{1}{tableID}{indexID}, value={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)})
	catalogCheckPrefix     = "CTL.CHECK."     // (key=CTL.CHECK.{1}{tableID}{checkID}, value={nameLen-1}{name}{expText})
	catalogPrivilegePrefix = "CTL.PRIVILEGE." // key/value layout is not defined in this part of the file — TODO document

	RowPrefix    = "R." // (key=R.{1}{tableID}{0}({null}({pkVal}{padding}{pkValLen})?)+, value={count (colID valLen val)+})
	MappedPrefix = "M." // (key=M.{tableID}{indexID}({null}({val}{padding}{valLen})?)*({pkVal}{padding}{pkValLen})+, value={count (colID valLen val)+})
)
47

48
// Fixed identifiers embedded in persisted keys.
const (
	// DatabaseID is deprecated but left to maintain backwards compatibility;
	// the catalog keys written in this file still embed it.
	DatabaseID uint32 = 1
	// PKIndexID is presumably the ID of a table's primary-key index — confirm
	// against the catalog code.
	PKIndexID uint32 = 0
)
52

53
// Flag bits stored in the first byte of a persisted column definition.
//
// NOTE: despite its name, nullableFlag is set for columns declared NOT NULL
// (see persistColumn).
const (
	nullableFlag      byte = 1 << iota // 0x01
	autoIncrementFlag                  // 0x02
)
57

58
// Reserved column names recognized by isReservedCol.
const (
	revCol        = "_rev"
	txMetadataCol = "_tx_metadata"
)

// reservedColumns is the lookup set backing isReservedCol.
var reservedColumns = map[string]struct{}{
	revCol:        {},
	txMetadataCol: {},
}

// isReservedCol reports whether col is one of the reserved column names.
func isReservedCol(col string) bool {
	_, reserved := reservedColumns[col]
	return reserved
}
15,685✔
72

73
// SQLValueType is the textual name of a SQL value type.
type SQLValueType = string

// SQL value types.
const (
	IntegerType   SQLValueType = "INTEGER"
	BooleanType   SQLValueType = "BOOLEAN"
	VarcharType   SQLValueType = "VARCHAR"
	UUIDType      SQLValueType = "UUID"
	BLOBType      SQLValueType = "BLOB"
	Float64Type   SQLValueType = "FLOAT"
	TimestampType SQLValueType = "TIMESTAMP"
	AnyType       SQLValueType = "ANY"
	JSONType      SQLValueType = "JSON"
)

// IsNumericType reports whether t is INTEGER or FLOAT.
func IsNumericType(t SQLValueType) bool {
	switch t {
	case IntegerType, Float64Type:
		return true
	default:
		return false
	}
}
221✔
90

91
// Permission is the textual name of a user permission level.
type Permission = string

// Permission levels.
const (
	PermissionReadOnly  Permission = "READ"
	PermissionReadWrite Permission = "READWRITE"
	PermissionAdmin     Permission = "ADMIN"
	PermissionSysAdmin  Permission = "SYSADMIN"
)

// PermissionFromCode maps a numeric permission code to its Permission:
// 1 -> READ, 2 -> READWRITE, 254 -> ADMIN, anything else -> SYSADMIN.
//
// NOTE(review): every unrecognized code, including 0, resolves to the most
// privileged level. If codes can come from untrusted or corrupted data this
// fails open; confirm that this fallback is intended.
func PermissionFromCode(code uint32) Permission {
	perm := PermissionSysAdmin

	switch code {
	case 1:
		perm = PermissionReadOnly
	case 2:
		perm = PermissionReadWrite
	case 254:
		perm = PermissionAdmin
	}

	return perm
}
117

118
// AggregateFn is the name of a supported SQL aggregate function.
type AggregateFn = string

// Supported aggregate functions; each value is its SQL keyword.
const (
	COUNT AggregateFn = "COUNT"
	SUM   AggregateFn = "SUM"
	MAX   AggregateFn = "MAX"
	MIN   AggregateFn = "MIN"
	AVG   AggregateFn = "AVG"
)
127

128
// CmpOperator identifies a binary comparison operator.
type CmpOperator = int

// Comparison operators, numbered consecutively from zero.
const (
	EQ CmpOperator = iota
	NE
	LT
	LE
	GT
	GE
)

// CmpOperatorToString returns the SQL symbol for op, or "" when op is not a
// known comparison operator.
func CmpOperatorToString(op CmpOperator) string {
	symbols := [...]string{
		EQ: "=",
		NE: "!=",
		LT: "<",
		LE: "<=",
		GT: ">",
		GE: ">=",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
156

157
// LogicOperator identifies a boolean connective.
type LogicOperator = int

// Boolean connectives.
const (
	And LogicOperator = iota
	Or
)

// LogicOperatorToString returns "AND" for And and "OR" for any other value.
func LogicOperatorToString(op LogicOperator) string {
	if op != And {
		return "OR"
	}
	return "AND"
}
170

171
// NumOperator identifies a binary arithmetic operator.
type NumOperator = int

// Arithmetic operators, numbered consecutively from zero.
const (
	ADDOP NumOperator = iota
	SUBSOP
	DIVOP
	MULTOP
	MODOP
)

// NumOperatorString returns the SQL symbol for op, or "" when op is not a
// known arithmetic operator.
func NumOperatorString(op NumOperator) string {
	symbols := [...]string{
		ADDOP:  "+",
		SUBSOP: "-",
		DIVOP:  "/",
		MULTOP: "*",
		MODOP:  "%",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
196

197
// JoinType identifies the kind of a JOIN clause.
type JoinType = int

// Join kinds, numbered consecutively from zero.
const (
	InnerJoin JoinType = iota // 0
	LeftJoin                  // 1
	RightJoin                 // 2
)
204

205
type SQLStmt interface {
206
        readOnly() bool
207
        requiredPrivileges() []SQLPrivilege
208
        execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error)
209
        inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error
210
}
211

212
type BeginTransactionStmt struct {
213
}
214

215
func (stmt *BeginTransactionStmt) readOnly() bool {
9✔
216
        return true
9✔
217
}
9✔
218

219
func (stmt *BeginTransactionStmt) requiredPrivileges() []SQLPrivilege {
9✔
220
        return nil
9✔
221
}
9✔
222

223
func (stmt *BeginTransactionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
224
        return nil
3✔
225
}
3✔
226

227
func (stmt *BeginTransactionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
61✔
228
        if tx.IsExplicitCloseRequired() {
62✔
229
                return nil, ErrNestedTxNotSupported
1✔
230
        }
1✔
231

232
        err := tx.RequireExplicitClose()
60✔
233
        if err == nil {
119✔
234
                // current tx can be reused as no changes were already made
59✔
235
                return tx, nil
59✔
236
        }
59✔
237

238
        // commit current transaction and start a fresh one
239

240
        err = tx.Commit(ctx)
1✔
241
        if err != nil {
1✔
242
                return nil, err
×
243
        }
×
244

245
        return tx.engine.NewTx(ctx, tx.opts.WithExplicitClose(true))
1✔
246
}
247

248
type CommitStmt struct {
249
}
250

251
func (stmt *CommitStmt) readOnly() bool {
113✔
252
        return true
113✔
253
}
113✔
254

255
func (stmt *CommitStmt) requiredPrivileges() []SQLPrivilege {
113✔
256
        return nil
113✔
257
}
113✔
258

259
func (stmt *CommitStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
260
        return nil
1✔
261
}
1✔
262

263
func (stmt *CommitStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
159✔
264
        if !tx.IsExplicitCloseRequired() {
160✔
265
                return nil, ErrNoOngoingTx
1✔
266
        }
1✔
267

268
        return nil, tx.Commit(ctx)
158✔
269
}
270

271
type RollbackStmt struct {
272
}
273

274
func (stmt *RollbackStmt) readOnly() bool {
1✔
275
        return true
1✔
276
}
1✔
277

278
func (stmt *RollbackStmt) requiredPrivileges() []SQLPrivilege {
1✔
279
        return nil
1✔
280
}
1✔
281

282
func (stmt *RollbackStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
283
        return nil
1✔
284
}
1✔
285

286
func (stmt *RollbackStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
287
        if !tx.IsExplicitCloseRequired() {
5✔
288
                return nil, ErrNoOngoingTx
1✔
289
        }
1✔
290

291
        return nil, tx.Cancel()
3✔
292
}
293

294
type CreateDatabaseStmt struct {
295
        DB          string
296
        ifNotExists bool
297
}
298

299
func (stmt *CreateDatabaseStmt) readOnly() bool {
14✔
300
        return false
14✔
301
}
14✔
302

303
func (stmt *CreateDatabaseStmt) requiredPrivileges() []SQLPrivilege {
14✔
304
        return []SQLPrivilege{SQLPrivilegeCreate}
14✔
305
}
14✔
306

307
func (stmt *CreateDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
308
        return nil
4✔
309
}
4✔
310

311
func (stmt *CreateDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
312
        if tx.IsExplicitCloseRequired() {
16✔
313
                return nil, fmt.Errorf("%w: database creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
314
        }
1✔
315

316
        if tx.engine.multidbHandler == nil {
16✔
317
                return nil, ErrUnspecifiedMultiDBHandler
2✔
318
        }
2✔
319

320
        return nil, tx.engine.multidbHandler.CreateDatabase(ctx, stmt.DB, stmt.ifNotExists)
12✔
321
}
322

323
type UseDatabaseStmt struct {
324
        DB string
325
}
326

327
func (stmt *UseDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
328
        return nil
1✔
329
}
1✔
330

331
func (stmt *UseDatabaseStmt) readOnly() bool {
9✔
332
        return true
9✔
333
}
9✔
334

335
func (stmt *UseDatabaseStmt) requiredPrivileges() []SQLPrivilege {
9✔
336
        return nil
9✔
337
}
9✔
338

339
func (stmt *UseDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
340
        if tx.IsExplicitCloseRequired() {
10✔
341
                return nil, fmt.Errorf("%w: database selection can NOT be executed within a transaction block", ErrNonTransactionalStmt)
1✔
342
        }
1✔
343

344
        if tx.engine.multidbHandler == nil {
9✔
345
                return nil, ErrUnspecifiedMultiDBHandler
1✔
346
        }
1✔
347

348
        return tx, tx.engine.multidbHandler.UseDatabase(ctx, stmt.DB)
7✔
349
}
350

351
type UseSnapshotStmt struct {
352
        period period
353
}
354

355
func (stmt *UseSnapshotStmt) readOnly() bool {
1✔
356
        return true
1✔
357
}
1✔
358

359
func (stmt *UseSnapshotStmt) requiredPrivileges() []SQLPrivilege {
1✔
360
        return nil
1✔
361
}
1✔
362

363
func (stmt *UseSnapshotStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
364
        return nil
1✔
365
}
1✔
366

367
func (stmt *UseSnapshotStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
368
        return nil, ErrNoSupported
1✔
369
}
1✔
370

371
type CreateUserStmt struct {
372
        username   string
373
        password   string
374
        permission Permission
375
}
376

377
func (stmt *CreateUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
378
        return nil
1✔
379
}
1✔
380

381
func (stmt *CreateUserStmt) readOnly() bool {
5✔
382
        return false
5✔
383
}
5✔
384

385
func (stmt *CreateUserStmt) requiredPrivileges() []SQLPrivilege {
5✔
386
        return []SQLPrivilege{SQLPrivilegeCreate}
5✔
387
}
5✔
388

389
func (stmt *CreateUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
390
        if tx.IsExplicitCloseRequired() {
7✔
391
                return nil, fmt.Errorf("%w: user creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
392
        }
1✔
393

394
        if tx.engine.multidbHandler == nil {
6✔
395
                return nil, ErrUnspecifiedMultiDBHandler
1✔
396
        }
1✔
397

398
        return nil, tx.engine.multidbHandler.CreateUser(ctx, stmt.username, stmt.password, stmt.permission)
4✔
399
}
400

401
type AlterUserStmt struct {
402
        username   string
403
        password   string
404
        permission Permission
405
}
406

407
func (stmt *AlterUserStmt) readOnly() bool {
4✔
408
        return false
4✔
409
}
4✔
410

411
func (stmt *AlterUserStmt) requiredPrivileges() []SQLPrivilege {
4✔
412
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
413
}
4✔
414

415
func (stmt *AlterUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
416
        return nil
1✔
417
}
1✔
418

419
func (stmt *AlterUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
5✔
420
        if tx.IsExplicitCloseRequired() {
6✔
421
                return nil, fmt.Errorf("%w: user modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
422
        }
1✔
423

424
        if tx.engine.multidbHandler == nil {
5✔
425
                return nil, ErrUnspecifiedMultiDBHandler
1✔
426
        }
1✔
427

428
        return nil, tx.engine.multidbHandler.AlterUser(ctx, stmt.username, stmt.password, stmt.permission)
3✔
429
}
430

431
type DropUserStmt struct {
432
        username string
433
}
434

435
func (stmt *DropUserStmt) readOnly() bool {
3✔
436
        return false
3✔
437
}
3✔
438

439
func (stmt *DropUserStmt) requiredPrivileges() []SQLPrivilege {
3✔
440
        return []SQLPrivilege{SQLPrivilegeDrop}
3✔
441
}
3✔
442

443
func (stmt *DropUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
444
        return nil
1✔
445
}
1✔
446

447
func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
448
        if tx.IsExplicitCloseRequired() {
4✔
449
                return nil, fmt.Errorf("%w: user deletion can not be done within a transaction", ErrNonTransactionalStmt)
1✔
450
        }
1✔
451

452
        if tx.engine.multidbHandler == nil {
3✔
453
                return nil, ErrUnspecifiedMultiDBHandler
1✔
454
        }
1✔
455

456
        return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
1✔
457
}
458

459
type TableElem interface{}
460

461
type CreateTableStmt struct {
462
        table       string
463
        ifNotExists bool
464
        colsSpec    []*ColSpec
465
        checks      []CheckConstraint
466
        pkColNames  PrimaryKeyConstraint
467
}
468

469
func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
38✔
470
        return &CreateTableStmt{table: table, ifNotExists: ifNotExists, colsSpec: colsSpec, pkColNames: pkColNames}
38✔
471
}
38✔
472

473
func (stmt *CreateTableStmt) readOnly() bool {
66✔
474
        return false
66✔
475
}
66✔
476

477
func (stmt *CreateTableStmt) requiredPrivileges() []SQLPrivilege {
64✔
478
        return []SQLPrivilege{SQLPrivilegeCreate}
64✔
479
}
64✔
480

481
func (stmt *CreateTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
482
        return nil
4✔
483
}
4✔
484

485
func zeroRow(tableName string, cols []*ColSpec) *Row {
244✔
486
        r := Row{
244✔
487
                ValuesByPosition: make([]TypedValue, len(cols)),
244✔
488
                ValuesBySelector: make(map[string]TypedValue, len(cols)),
244✔
489
        }
244✔
490

244✔
491
        for i, col := range cols {
1,024✔
492
                v := zeroForType(col.colType)
780✔
493

780✔
494
                r.ValuesByPosition[i] = v
780✔
495
                r.ValuesBySelector[EncodeSelector("", tableName, col.colName)] = v
780✔
496
        }
780✔
497
        return &r
244✔
498
}
499

500
func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
235✔
501
        if err := stmt.validatePrimaryKey(); err != nil {
238✔
502
                return nil, err
3✔
503
        }
3✔
504

505
        if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
233✔
506
                return tx, nil
1✔
507
        }
1✔
508

509
        colSpecs := make(map[uint32]*ColSpec, len(stmt.colsSpec))
231✔
510
        for i, cs := range stmt.colsSpec {
951✔
511
                colSpecs[uint32(i)+1] = cs
720✔
512
        }
720✔
513

514
        row := zeroRow(stmt.table, stmt.colsSpec)
231✔
515
        for _, check := range stmt.checks {
240✔
516
                value, err := check.exp.reduce(tx, row, stmt.table)
9✔
517
                if err != nil {
11✔
518
                        return nil, err
2✔
519
                }
2✔
520

521
                if value.Type() != BooleanType {
7✔
522
                        return nil, ErrInvalidCheckConstraint
×
523
                }
×
524
        }
525

526
        nextUnnamedCheck := 0
229✔
527
        checks := make(map[string]CheckConstraint)
229✔
528
        for id, check := range stmt.checks {
236✔
529
                name := fmt.Sprintf("%s_check%d", stmt.table, nextUnnamedCheck+1)
7✔
530
                if check.name != "" {
9✔
531
                        name = check.name
2✔
532
                } else {
7✔
533
                        nextUnnamedCheck++
5✔
534
                }
5✔
535
                check.id = uint32(id)
7✔
536
                check.name = name
7✔
537
                checks[name] = check
7✔
538
        }
539

540
        table, err := tx.catalog.newTable(stmt.table, colSpecs, checks, uint32(len(colSpecs)))
229✔
541
        if err != nil {
235✔
542
                return nil, err
6✔
543
        }
6✔
544

545
        createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.primaryKeyCols()}
223✔
546
        _, err = createIndexStmt.execAt(ctx, tx, params)
223✔
547
        if err != nil {
228✔
548
                return nil, err
5✔
549
        }
5✔
550

551
        for _, col := range table.cols {
919✔
552
                if col.autoIncrement {
778✔
553
                        if len(table.primaryIndex.cols) > 1 || col.id != table.primaryIndex.cols[0].id {
78✔
554
                                return nil, ErrLimitedAutoIncrement
1✔
555
                        }
1✔
556
                }
557

558
                err := persistColumn(tx, col)
700✔
559
                if err != nil {
700✔
560
                        return nil, err
×
561
                }
×
562
        }
563

564
        for _, check := range checks {
224✔
565
                if err := persistCheck(tx, table, &check); err != nil {
7✔
566
                        return nil, err
×
567
                }
×
568
        }
569

570
        mappedKey := MapKey(tx.sqlPrefix(), catalogTablePrefix, EncodeID(DatabaseID), EncodeID(table.id))
217✔
571

217✔
572
        err = tx.set(mappedKey, nil, []byte(table.name))
217✔
573
        if err != nil {
217✔
574
                return nil, err
×
575
        }
×
576

577
        tx.mutatedCatalog = true
217✔
578

217✔
579
        return tx, nil
217✔
580
}
581

582
func (stmt *CreateTableStmt) validatePrimaryKey() error {
235✔
583
        n := 0
235✔
584
        for _, spec := range stmt.colsSpec {
962✔
585
                if spec.primaryKey {
732✔
586
                        n++
5✔
587
                }
5✔
588
        }
589

590
        if len(stmt.pkColNames) > 0 {
466✔
591
                n++
231✔
592
        }
231✔
593

594
        switch n {
235✔
595
        case 0:
1✔
596
                return ErrNoPrimaryKey
1✔
597
        case 1:
232✔
598
                return nil
232✔
599
        }
600
        return fmt.Errorf("\"%s\": %w", stmt.table, ErrMultiplePrimaryKeys)
2✔
601
}
602

603
func (stmt *CreateTableStmt) primaryKeyCols() []string {
223✔
604
        if len(stmt.pkColNames) > 0 {
444✔
605
                return stmt.pkColNames
221✔
606
        }
221✔
607

608
        for _, spec := range stmt.colsSpec {
4✔
609
                if spec.primaryKey {
4✔
610
                        return []string{spec.colName}
2✔
611
                }
2✔
612
        }
613
        return nil
×
614
}
615

616
func persistColumn(tx *SQLTx, col *Column) error {
720✔
617
        //{auto_incremental | nullable}{maxLen}{colNAME})
720✔
618
        v := make([]byte, 1+4+len(col.colName))
720✔
619

720✔
620
        if col.autoIncrement {
796✔
621
                v[0] = v[0] | autoIncrementFlag
76✔
622
        }
76✔
623

624
        if col.notNull {
767✔
625
                v[0] = v[0] | nullableFlag
47✔
626
        }
47✔
627

628
        binary.BigEndian.PutUint32(v[1:], uint32(col.MaxLen()))
720✔
629

720✔
630
        copy(v[5:], []byte(col.Name()))
720✔
631

720✔
632
        mappedKey := MapKey(
720✔
633
                tx.sqlPrefix(),
720✔
634
                catalogColumnPrefix,
720✔
635
                EncodeID(DatabaseID),
720✔
636
                EncodeID(col.table.id),
720✔
637
                EncodeID(col.id),
720✔
638
                []byte(col.colType),
720✔
639
        )
720✔
640

720✔
641
        return tx.set(mappedKey, nil, v)
720✔
642
}
643

644
func persistCheck(tx *SQLTx, table *Table, check *CheckConstraint) error {
7✔
645
        mappedKey := MapKey(
7✔
646
                tx.sqlPrefix(),
7✔
647
                catalogCheckPrefix,
7✔
648
                EncodeID(DatabaseID),
7✔
649
                EncodeID(table.id),
7✔
650
                EncodeID(check.id),
7✔
651
        )
7✔
652

7✔
653
        name := check.name
7✔
654
        expText := check.exp.String()
7✔
655

7✔
656
        val := make([]byte, 2+len(name)+len(expText))
7✔
657

7✔
658
        if len(name) > 256 {
7✔
659
                return fmt.Errorf("constraint name len: %w", ErrMaxLengthExceeded)
×
660
        }
×
661

662
        val[0] = byte(len(name)) - 1
7✔
663

7✔
664
        copy(val[1:], []byte(name))
7✔
665
        copy(val[1+len(name):], []byte(expText))
7✔
666

7✔
667
        return tx.set(mappedKey, nil, val)
7✔
668
}
669

670
type ColSpec struct {
671
        colName       string
672
        colType       SQLValueType
673
        maxLen        int
674
        autoIncrement bool
675
        notNull       bool
676
        primaryKey    bool
677
}
678

679
func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {
188✔
680
        return &ColSpec{
188✔
681
                colName:       name,
188✔
682
                colType:       colType,
188✔
683
                maxLen:        maxLen,
188✔
684
                autoIncrement: autoIncrement,
188✔
685
                notNull:       notNull,
188✔
686
        }
188✔
687
}
188✔
688

689
type CreateIndexStmt struct {
690
        unique      bool
691
        ifNotExists bool
692
        table       string
693
        cols        []string
694
}
695

696
func NewCreateIndexStmt(table string, cols []string, isUnique bool) *CreateIndexStmt {
72✔
697
        return &CreateIndexStmt{unique: isUnique, table: table, cols: cols}
72✔
698
}
72✔
699

700
func (stmt *CreateIndexStmt) readOnly() bool {
7✔
701
        return false
7✔
702
}
7✔
703

704
func (stmt *CreateIndexStmt) requiredPrivileges() []SQLPrivilege {
7✔
705
        return []SQLPrivilege{SQLPrivilegeCreate}
7✔
706
}
7✔
707

708
func (stmt *CreateIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
709
        return nil
1✔
710
}
1✔
711

712
func (stmt *CreateIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
374✔
713
        if len(stmt.cols) < 1 {
375✔
714
                return nil, ErrIllegalArguments
1✔
715
        }
1✔
716

717
        if len(stmt.cols) > MaxNumberOfColumnsInIndex {
374✔
718
                return nil, ErrMaxNumberOfColumnsInIndexExceeded
1✔
719
        }
1✔
720

721
        table, err := tx.catalog.GetTableByName(stmt.table)
372✔
722
        if err != nil {
374✔
723
                return nil, err
2✔
724
        }
2✔
725

726
        colIDs := make([]uint32, len(stmt.cols))
370✔
727

370✔
728
        indexKeyLen := 0
370✔
729

370✔
730
        for i, colName := range stmt.cols {
767✔
731
                col, err := table.GetColumnByName(colName)
397✔
732
                if err != nil {
402✔
733
                        return nil, err
5✔
734
                }
5✔
735

736
                if col.Type() == JSONType {
394✔
737
                        return nil, ErrCannotIndexJson
2✔
738
                }
2✔
739

740
                if variableSizedType(col.colType) && !tx.engine.lazyIndexConstraintValidation && (col.MaxLen() == 0 || col.MaxLen() > MaxKeyLen) {
392✔
741
                        return nil, fmt.Errorf("%w: can not create index using column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
2✔
742
                }
2✔
743

744
                indexKeyLen += col.MaxLen()
388✔
745

388✔
746
                colIDs[i] = col.id
388✔
747
        }
748

749
        if !tx.engine.lazyIndexConstraintValidation && indexKeyLen > MaxKeyLen {
361✔
750
                return nil, fmt.Errorf("%w: can not create index using columns '%v'. Max key length is %d", ErrLimitedKeyType, stmt.cols, MaxKeyLen)
×
751
        }
×
752

753
        if stmt.unique && table.primaryIndex != nil {
381✔
754
                // check table is empty
20✔
755
                pkPrefix := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id))
20✔
756
                _, _, err := tx.getWithPrefix(ctx, pkPrefix, nil)
20✔
757
                if errors.Is(err, store.ErrIndexNotFound) {
20✔
758
                        return nil, ErrTableDoesNotExist
×
759
                }
×
760
                if err == nil {
21✔
761
                        return nil, ErrLimitedIndexCreation
1✔
762
                } else if !errors.Is(err, store.ErrKeyNotFound) {
20✔
763
                        return nil, err
×
764
                }
×
765
        }
766

767
        index, err := table.newIndex(stmt.unique, colIDs)
360✔
768
        if errors.Is(err, ErrIndexAlreadyExists) && stmt.ifNotExists {
362✔
769
                return tx, nil
2✔
770
        }
2✔
771
        if err != nil {
362✔
772
                return nil, err
4✔
773
        }
4✔
774

775
        // v={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)}
776
        // TODO: currently only ASC order is supported
777
        colSpecLen := EncIDLen + 1
354✔
778

354✔
779
        encodedValues := make([]byte, 1+len(index.cols)*colSpecLen)
354✔
780

354✔
781
        if index.IsUnique() {
590✔
782
                encodedValues[0] = 1
236✔
783
        }
236✔
784

785
        for i, col := range index.cols {
735✔
786
                copy(encodedValues[1+i*colSpecLen:], EncodeID(col.id))
381✔
787
        }
381✔
788

789
        mappedKey := MapKey(tx.sqlPrefix(), catalogIndexPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(index.id))
354✔
790

354✔
791
        err = tx.set(mappedKey, nil, encodedValues)
354✔
792
        if err != nil {
354✔
793
                return nil, err
×
794
        }
×
795

796
        tx.mutatedCatalog = true
354✔
797

354✔
798
        return tx, nil
354✔
799
}
800

801
type AddColumnStmt struct {
802
        table   string
803
        colSpec *ColSpec
804
}
805

806
func NewAddColumnStmt(table string, colSpec *ColSpec) *AddColumnStmt {
6✔
807
        return &AddColumnStmt{table: table, colSpec: colSpec}
6✔
808
}
6✔
809

810
func (stmt *AddColumnStmt) readOnly() bool {
4✔
811
        return false
4✔
812
}
4✔
813

814
func (stmt *AddColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
815
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
816
}
4✔
817

818
func (stmt *AddColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
819
        return nil
1✔
820
}
1✔
821

822
func (stmt *AddColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
823
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
824
        if err != nil {
20✔
825
                return nil, err
1✔
826
        }
1✔
827

828
        col, err := table.newColumn(stmt.colSpec)
18✔
829
        if err != nil {
24✔
830
                return nil, err
6✔
831
        }
6✔
832

833
        err = persistColumn(tx, col)
12✔
834
        if err != nil {
12✔
835
                return nil, err
×
836
        }
×
837

838
        tx.mutatedCatalog = true
12✔
839

12✔
840
        return tx, nil
12✔
841
}
842

843
type RenameTableStmt struct {
844
        oldName string
845
        newName string
846
}
847

848
func (stmt *RenameTableStmt) readOnly() bool {
1✔
849
        return false
1✔
850
}
1✔
851

852
func (stmt *RenameTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
853
        return []SQLPrivilege{SQLPrivilegeAlter}
1✔
854
}
1✔
855

856
func (stmt *RenameTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
857
        return nil
1✔
858
}
1✔
859

860
func (stmt *RenameTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
861
        table, err := tx.catalog.renameTable(stmt.oldName, stmt.newName)
6✔
862
        if err != nil {
10✔
863
                return nil, err
4✔
864
        }
4✔
865

866
        // update table name
867
        mappedKey := MapKey(
2✔
868
                tx.sqlPrefix(),
2✔
869
                catalogTablePrefix,
2✔
870
                EncodeID(DatabaseID),
2✔
871
                EncodeID(table.id),
2✔
872
        )
2✔
873
        err = tx.set(mappedKey, nil, []byte(stmt.newName))
2✔
874
        if err != nil {
2✔
875
                return nil, err
×
876
        }
×
877

878
        tx.mutatedCatalog = true
2✔
879

2✔
880
        return tx, nil
2✔
881
}
882

883
type RenameColumnStmt struct {
884
        table   string
885
        oldName string
886
        newName string
887
}
888

889
func NewRenameColumnStmt(table, oldName, newName string) *RenameColumnStmt {
3✔
890
        return &RenameColumnStmt{table: table, oldName: oldName, newName: newName}
3✔
891
}
3✔
892

893
func (stmt *RenameColumnStmt) readOnly() bool {
4✔
894
        return false
4✔
895
}
4✔
896

897
func (stmt *RenameColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
898
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
899
}
4✔
900

901
func (stmt *RenameColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
902
        return nil
1✔
903
}
1✔
904

905
func (stmt *RenameColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
10✔
906
        table, err := tx.catalog.GetTableByName(stmt.table)
10✔
907
        if err != nil {
11✔
908
                return nil, err
1✔
909
        }
1✔
910

911
        col, err := table.renameColumn(stmt.oldName, stmt.newName)
9✔
912
        if err != nil {
12✔
913
                return nil, err
3✔
914
        }
3✔
915

916
        err = persistColumn(tx, col)
6✔
917
        if err != nil {
6✔
918
                return nil, err
×
919
        }
×
920

921
        tx.mutatedCatalog = true
6✔
922

6✔
923
        return tx, nil
6✔
924
}
925

926
type DropColumnStmt struct {
927
        table   string
928
        colName string
929
}
930

931
func NewDropColumnStmt(table, colName string) *DropColumnStmt {
8✔
932
        return &DropColumnStmt{table: table, colName: colName}
8✔
933
}
8✔
934

935
func (stmt *DropColumnStmt) readOnly() bool {
2✔
936
        return false
2✔
937
}
2✔
938

939
func (stmt *DropColumnStmt) requiredPrivileges() []SQLPrivilege {
2✔
940
        return []SQLPrivilege{SQLPrivilegeDrop}
2✔
941
}
2✔
942

943
func (stmt *DropColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
944
        return nil
1✔
945
}
1✔
946

947
func (stmt *DropColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
948
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
949
        if err != nil {
21✔
950
                return nil, err
2✔
951
        }
2✔
952

953
        col, err := table.GetColumnByName(stmt.colName)
17✔
954
        if err != nil {
21✔
955
                return nil, err
4✔
956
        }
4✔
957

958
        err = canDropColumn(tx, table, col)
13✔
959
        if err != nil {
14✔
960
                return nil, err
1✔
961
        }
1✔
962

963
        err = table.deleteColumn(col)
12✔
964
        if err != nil {
16✔
965
                return nil, err
4✔
966
        }
4✔
967

968
        err = persistColumnDeletion(ctx, tx, col)
8✔
969
        if err != nil {
8✔
970
                return nil, err
×
971
        }
×
972

973
        tx.mutatedCatalog = true
8✔
974

8✔
975
        return tx, nil
8✔
976
}
977

978
func canDropColumn(tx *SQLTx, table *Table, col *Column) error {
13✔
979
        colSpecs := make([]*ColSpec, 0, len(table.Cols())-1)
13✔
980
        for _, c := range table.cols {
86✔
981
                if c.id != col.id {
133✔
982
                        colSpecs = append(colSpecs, &ColSpec{colName: c.Name(), colType: c.Type()})
60✔
983
                }
60✔
984
        }
985

986
        row := zeroRow(table.Name(), colSpecs)
13✔
987
        for name, check := range table.checkConstraints {
20✔
988
                _, err := check.exp.reduce(tx, row, table.name)
7✔
989
                if errors.Is(err, ErrColumnDoesNotExist) {
8✔
990
                        return fmt.Errorf("%w %s because %s constraint requires it", ErrCannotDropColumn, col.Name(), name)
1✔
991
                }
1✔
992

993
                if err != nil {
6✔
994
                        return err
×
995
                }
×
996
        }
997
        return nil
12✔
998
}
999

1000
func persistColumnDeletion(ctx context.Context, tx *SQLTx, col *Column) error {
9✔
1001
        mappedKey := MapKey(
9✔
1002
                tx.sqlPrefix(),
9✔
1003
                catalogColumnPrefix,
9✔
1004
                EncodeID(DatabaseID),
9✔
1005
                EncodeID(col.table.id),
9✔
1006
                EncodeID(col.id),
9✔
1007
                []byte(col.colType),
9✔
1008
        )
9✔
1009

9✔
1010
        return tx.delete(ctx, mappedKey)
9✔
1011
}
9✔
1012

1013
type DropConstraintStmt struct {
1014
        table          string
1015
        constraintName string
1016
}
1017

1018
func (stmt *DropConstraintStmt) readOnly() bool {
×
1019
        return false
×
1020
}
×
1021

1022
func (stmt *DropConstraintStmt) requiredPrivileges() []SQLPrivilege {
×
1023
        return []SQLPrivilege{SQLPrivilegeDrop}
×
1024
}
×
1025

1026
func (stmt *DropConstraintStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
1027
        table, err := tx.catalog.GetTableByName(stmt.table)
4✔
1028
        if err != nil {
4✔
1029
                return nil, err
×
1030
        }
×
1031

1032
        id, err := table.deleteCheck(stmt.constraintName)
4✔
1033
        if err != nil {
5✔
1034
                return nil, err
1✔
1035
        }
1✔
1036

1037
        err = persistCheckDeletion(ctx, tx, table.id, id)
3✔
1038

3✔
1039
        tx.mutatedCatalog = true
3✔
1040

3✔
1041
        return tx, err
3✔
1042
}
1043

1044
func persistCheckDeletion(ctx context.Context, tx *SQLTx, tableID uint32, checkId uint32) error {
3✔
1045
        mappedKey := MapKey(
3✔
1046
                tx.sqlPrefix(),
3✔
1047
                catalogCheckPrefix,
3✔
1048
                EncodeID(DatabaseID),
3✔
1049
                EncodeID(tableID),
3✔
1050
                EncodeID(checkId),
3✔
1051
        )
3✔
1052
        return tx.delete(ctx, mappedKey)
3✔
1053
}
3✔
1054

1055
func (stmt *DropConstraintStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
1056
        return nil
×
1057
}
×
1058

1059
type UpsertIntoStmt struct {
1060
        isInsert   bool
1061
        tableRef   *tableRef
1062
        cols       []string
1063
        ds         DataSource
1064
        onConflict *OnConflictDo
1065
}
1066

1067
func (stmt *UpsertIntoStmt) readOnly() bool {
101✔
1068
        return false
101✔
1069
}
101✔
1070

1071
func (stmt *UpsertIntoStmt) requiredPrivileges() []SQLPrivilege {
101✔
1072
        privileges := stmt.privileges()
101✔
1073
        if stmt.ds != nil {
200✔
1074
                privileges = append(privileges, stmt.ds.requiredPrivileges()...)
99✔
1075
        }
99✔
1076
        return privileges
101✔
1077
}
1078

1079
func (stmt *UpsertIntoStmt) privileges() []SQLPrivilege {
101✔
1080
        if stmt.isInsert {
190✔
1081
                return []SQLPrivilege{SQLPrivilegeInsert}
89✔
1082
        }
89✔
1083
        return []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate}
12✔
1084
}
1085

1086
func NewUpsertIntoStmt(table string, cols []string, ds DataSource, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt {
120✔
1087
        return &UpsertIntoStmt{
120✔
1088
                isInsert:   isInsert,
120✔
1089
                tableRef:   NewTableRef(table, ""),
120✔
1090
                cols:       cols,
120✔
1091
                ds:         ds,
120✔
1092
                onConflict: onConflict,
120✔
1093
        }
120✔
1094
}
120✔
1095

1096
type RowSpec struct {
1097
        Values []ValueExp
1098
}
1099

1100
func NewRowSpec(values []ValueExp) *RowSpec {
129✔
1101
        return &RowSpec{
129✔
1102
                Values: values,
129✔
1103
        }
129✔
1104
}
129✔
1105

1106
type OnConflictDo struct{}
1107

1108
func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
11✔
1109
        ds, ok := stmt.ds.(*valuesDataSource)
11✔
1110
        if !ok {
11✔
1111
                return stmt.ds.inferParameters(ctx, tx, params)
×
1112
        }
×
1113

1114
        emptyDescriptors := make(map[string]ColDescriptor)
11✔
1115
        for _, row := range ds.rows {
23✔
1116
                if len(stmt.cols) != len(row.Values) {
13✔
1117
                        return ErrInvalidNumberOfValues
1✔
1118
                }
1✔
1119

1120
                for i, val := range row.Values {
36✔
1121
                        table, err := stmt.tableRef.referencedTable(tx)
25✔
1122
                        if err != nil {
26✔
1123
                                return err
1✔
1124
                        }
1✔
1125

1126
                        col, err := table.GetColumnByName(stmt.cols[i])
24✔
1127
                        if err != nil {
25✔
1128
                                return err
1✔
1129
                        }
1✔
1130

1131
                        err = val.requiresType(col.colType, emptyDescriptors, params, table.name)
23✔
1132
                        if err != nil {
25✔
1133
                                return err
2✔
1134
                        }
2✔
1135
                }
1136
        }
1137
        return nil
6✔
1138
}
1139

1140
func (stmt *UpsertIntoStmt) validate(table *Table) (map[uint32]int, error) {
2,115✔
1141
        selPosByColID := make(map[uint32]int, len(stmt.cols))
2,115✔
1142

2,115✔
1143
        for i, c := range stmt.cols {
9,736✔
1144
                col, err := table.GetColumnByName(c)
7,621✔
1145
                if err != nil {
7,623✔
1146
                        return nil, err
2✔
1147
                }
2✔
1148

1149
                _, duplicated := selPosByColID[col.id]
7,619✔
1150
                if duplicated {
7,620✔
1151
                        return nil, fmt.Errorf("%w (%s)", ErrDuplicatedColumn, col.colName)
1✔
1152
                }
1✔
1153

1154
                selPosByColID[col.id] = i
7,618✔
1155
        }
1156

1157
        return selPosByColID, nil
2,112✔
1158
}
1159

1160
func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2,118✔
1161
        table, err := stmt.tableRef.referencedTable(tx)
2,118✔
1162
        if err != nil {
2,121✔
1163
                return nil, err
3✔
1164
        }
3✔
1165

1166
        selPosByColID, err := stmt.validate(table)
2,115✔
1167
        if err != nil {
2,118✔
1168
                return nil, err
3✔
1169
        }
3✔
1170

1171
        r := &Row{
2,112✔
1172
                ValuesByPosition: make([]TypedValue, len(table.cols)),
2,112✔
1173
                ValuesBySelector: make(map[string]TypedValue),
2,112✔
1174
        }
2,112✔
1175

2,112✔
1176
        reader, err := stmt.ds.Resolve(ctx, tx, params, nil)
2,112✔
1177
        if err != nil {
2,112✔
1178
                return nil, err
×
1179
        }
×
1180
        defer reader.Close()
2,112✔
1181

2,112✔
1182
        for {
6,490✔
1183
                row, err := reader.Read(ctx)
4,378✔
1184
                if errors.Is(err, ErrNoMoreRows) {
6,447✔
1185
                        break
2,069✔
1186
                }
1187
                if err != nil {
2,319✔
1188
                        return nil, err
10✔
1189
                }
10✔
1190

1191
                if len(row.ValuesByPosition) != len(stmt.cols) {
2,301✔
1192
                        return nil, ErrInvalidNumberOfValues
2✔
1193
                }
2✔
1194

1195
                valuesByColID := make(map[uint32]TypedValue)
2,297✔
1196

2,297✔
1197
                var pkMustExist bool
2,297✔
1198

2,297✔
1199
                for colID, col := range table.colsByID {
11,742✔
1200
                        colPos, specified := selPosByColID[colID]
9,445✔
1201
                        if !specified {
10,691✔
1202
                                // TODO: Default values
1,246✔
1203
                                if col.notNull && !col.autoIncrement {
1,247✔
1204
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
1✔
1205
                                }
1✔
1206

1207
                                // inject auto-incremental pk value
1208
                                if stmt.isInsert && col.autoIncrement {
2,303✔
1209
                                        // current implementation assumes only PK can be set as autoincremental
1,058✔
1210
                                        table.maxPK++
1,058✔
1211

1,058✔
1212
                                        pkCol := table.primaryIndex.cols[0]
1,058✔
1213
                                        valuesByColID[pkCol.id] = &Integer{val: table.maxPK}
1,058✔
1214

1,058✔
1215
                                        if _, ok := tx.firstInsertedPKs[table.name]; !ok {
1,924✔
1216
                                                tx.firstInsertedPKs[table.name] = table.maxPK
866✔
1217
                                        }
866✔
1218
                                        tx.lastInsertedPKs[table.name] = table.maxPK
1,058✔
1219
                                }
1220

1221
                                continue
1,245✔
1222
                        }
1223

1224
                        // value was specified
1225
                        cVal := row.ValuesByPosition[colPos]
8,199✔
1226

8,199✔
1227
                        val, err := cVal.substitute(params)
8,199✔
1228
                        if err != nil {
8,199✔
1229
                                return nil, err
×
1230
                        }
×
1231

1232
                        rval, err := val.reduce(tx, nil, table.name)
8,199✔
1233
                        if err != nil {
8,199✔
1234
                                return nil, err
×
1235
                        }
×
1236

1237
                        if rval.IsNull() {
8,297✔
1238
                                if col.notNull || col.autoIncrement {
98✔
1239
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
×
1240
                                }
×
1241

1242
                                continue
98✔
1243
                        }
1244

1245
                        if col.autoIncrement {
8,120✔
1246
                                // validate specified value
19✔
1247
                                nl, isNumber := rval.RawValue().(int64)
19✔
1248
                                if !isNumber {
19✔
1249
                                        return nil, fmt.Errorf("%w (expecting numeric value)", ErrInvalidValue)
×
1250
                                }
×
1251

1252
                                pkMustExist = nl <= table.maxPK
19✔
1253

19✔
1254
                                if _, ok := tx.firstInsertedPKs[table.name]; !ok {
38✔
1255
                                        tx.firstInsertedPKs[table.name] = nl
19✔
1256
                                }
19✔
1257
                                tx.lastInsertedPKs[table.name] = nl
19✔
1258
                        }
1259

1260
                        valuesByColID[colID] = rval
8,101✔
1261
                }
1262

1263
                for i, col := range table.cols {
11,739✔
1264
                        v := valuesByColID[col.id]
9,443✔
1265

9,443✔
1266
                        if v == nil {
9,727✔
1267
                                v = NewNull(AnyType)
284✔
1268
                        } else if len(table.checkConstraints) > 0 && col.Type() == JSONType {
9,448✔
1269
                                s, _ := v.RawValue().(string)
5✔
1270
                                jsonVal, err := NewJsonFromString(s)
5✔
1271
                                if err != nil {
5✔
1272
                                        return nil, err
×
1273
                                }
×
1274
                                v = jsonVal
5✔
1275
                        }
1276

1277
                        r.ValuesByPosition[i] = v
9,443✔
1278
                        r.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
9,443✔
1279
                }
1280

1281
                if err := checkConstraints(tx, table.checkConstraints, r, table.name); err != nil {
2,302✔
1282
                        return nil, err
6✔
1283
                }
6✔
1284

1285
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
2,290✔
1286
                if err != nil {
2,295✔
1287
                        return nil, err
5✔
1288
                }
5✔
1289

1290
                // pk entry
1291
                mappedPKey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
2,285✔
1292
                if len(mappedPKey) > MaxKeyLen {
2,285✔
1293
                        return nil, ErrMaxKeyLengthExceeded
×
1294
                }
×
1295

1296
                _, err = tx.get(ctx, mappedPKey)
2,285✔
1297
                if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
2,285✔
1298
                        return nil, err
×
1299
                }
×
1300

1301
                if errors.Is(err, store.ErrKeyNotFound) && pkMustExist {
2,287✔
1302
                        return nil, fmt.Errorf("%w: specified value must be greater than current one", ErrInvalidValue)
2✔
1303
                }
2✔
1304

1305
                if stmt.isInsert {
4,385✔
1306
                        if err == nil && stmt.onConflict == nil {
2,106✔
1307
                                return nil, store.ErrKeyAlreadyExists
4✔
1308
                        }
4✔
1309

1310
                        if err == nil && stmt.onConflict != nil {
2,101✔
1311
                                // TODO: conflict resolution may be extended. Currently only supports "ON CONFLICT DO NOTHING"
3✔
1312
                                continue
3✔
1313
                        }
1314
                }
1315

1316
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, !stmt.isInsert)
2,276✔
1317
                if err != nil {
2,289✔
1318
                        return nil, err
13✔
1319
                }
13✔
1320
        }
1321
        return tx, nil
2,069✔
1322
}
1323

1324
func checkConstraints(tx *SQLTx, checks map[string]CheckConstraint, row *Row, table string) error {
2,329✔
1325
        for _, check := range checks {
2,373✔
1326
                val, err := check.exp.reduce(tx, row, table)
44✔
1327
                if err != nil {
45✔
1328
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, err)
1✔
1329
                }
1✔
1330

1331
                if val.Type() != BooleanType {
43✔
1332
                        return ErrInvalidCheckConstraint
×
1333
                }
×
1334

1335
                if !val.RawValue().(bool) {
50✔
1336
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, check.exp.String())
7✔
1337
                }
7✔
1338
        }
1339
        return nil
2,321✔
1340
}
1341

1342
func (tx *SQLTx) encodeRowValue(valuesByColID map[uint32]TypedValue, table *Table) ([]byte, error) {
2,464✔
1343
        valbuf := bytes.Buffer{}
2,464✔
1344

2,464✔
1345
        // null values are not serialized
2,464✔
1346
        encodedVals := 0
2,464✔
1347
        for _, v := range valuesByColID {
12,042✔
1348
                if !v.IsNull() {
19,138✔
1349
                        encodedVals++
9,560✔
1350
                }
9,560✔
1351
        }
1352

1353
        b := make([]byte, EncLenLen)
2,464✔
1354
        binary.BigEndian.PutUint32(b, uint32(encodedVals))
2,464✔
1355

2,464✔
1356
        _, err := valbuf.Write(b)
2,464✔
1357
        if err != nil {
2,464✔
1358
                return nil, err
×
1359
        }
×
1360

1361
        for _, col := range table.cols {
12,302✔
1362
                rval, specified := valuesByColID[col.id]
9,838✔
1363
                if !specified || rval.IsNull() {
10,122✔
1364
                        continue
284✔
1365
                }
1366

1367
                b := make([]byte, EncIDLen)
9,554✔
1368
                binary.BigEndian.PutUint32(b, uint32(col.id))
9,554✔
1369

9,554✔
1370
                _, err = valbuf.Write(b)
9,554✔
1371
                if err != nil {
9,554✔
1372
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1373
                }
×
1374

1375
                encVal, err := EncodeValue(rval, col.colType, col.MaxLen())
9,554✔
1376
                if err != nil {
9,562✔
1377
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
8✔
1378
                }
8✔
1379

1380
                _, err = valbuf.Write(encVal)
9,546✔
1381
                if err != nil {
9,546✔
1382
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1383
                }
×
1384
        }
1385

1386
        return valbuf.Bytes(), nil
2,456✔
1387
}
1388

1389
func (tx *SQLTx) doUpsert(ctx context.Context, pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table, reuseIndex bool) error {
2,307✔
1390
        var reusableIndexEntries map[uint32]struct{}
2,307✔
1391

2,307✔
1392
        if reuseIndex && len(table.indexes) > 1 {
2,364✔
1393
                currPKRow, err := tx.fetchPKRow(ctx, table, valuesByColID)
57✔
1394
                if err == nil {
93✔
1395
                        currValuesByColID := make(map[uint32]TypedValue, len(currPKRow.ValuesBySelector))
36✔
1396

36✔
1397
                        for _, col := range table.cols {
161✔
1398
                                encSel := EncodeSelector("", table.name, col.colName)
125✔
1399
                                currValuesByColID[col.id] = currPKRow.ValuesBySelector[encSel]
125✔
1400
                        }
125✔
1401

1402
                        reusableIndexEntries, err = tx.deprecateIndexEntries(pkEncVals, currValuesByColID, valuesByColID, table)
36✔
1403
                        if err != nil {
36✔
1404
                                return err
×
1405
                        }
×
1406
                } else if !errors.Is(err, ErrNoMoreRows) {
21✔
1407
                        return err
×
1408
                }
×
1409
        }
1410

1411
        rowKey := MapKey(tx.sqlPrefix(), RowPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(PKIndexID), pkEncVals)
2,307✔
1412

2,307✔
1413
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
2,307✔
1414
        if err != nil {
2,315✔
1415
                return err
8✔
1416
        }
8✔
1417

1418
        err = tx.set(rowKey, nil, encodedRowValue)
2,299✔
1419
        if err != nil {
2,299✔
1420
                return err
×
1421
        }
×
1422

1423
        // create in-memory and validate entries for secondary indexes
1424
        for _, index := range table.indexes {
5,493✔
1425
                if index.IsPrimary() {
5,493✔
1426
                        continue
2,299✔
1427
                }
1428

1429
                if reusableIndexEntries != nil {
972✔
1430
                        _, reusable := reusableIndexEntries[index.id]
77✔
1431
                        if reusable {
127✔
1432
                                continue
50✔
1433
                        }
1434
                }
1435

1436
                encodedValues := make([][]byte, 2+len(index.cols))
845✔
1437
                encodedValues[0] = EncodeID(table.id)
845✔
1438
                encodedValues[1] = EncodeID(index.id)
845✔
1439

845✔
1440
                indexKeyLen := 0
845✔
1441

845✔
1442
                for i, col := range index.cols {
1,765✔
1443
                        rval, specified := valuesByColID[col.id]
920✔
1444
                        if !specified {
993✔
1445
                                rval = &NullValue{t: col.colType}
73✔
1446
                        }
73✔
1447

1448
                        encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
920✔
1449
                        if err != nil {
920✔
1450
                                return fmt.Errorf("%w: index on '%s' and column '%s'", err, index.Name(), col.colName)
×
1451
                        }
×
1452

1453
                        if n > MaxKeyLen {
920✔
1454
                                return fmt.Errorf("%w: can not index entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1455
                        }
×
1456

1457
                        indexKeyLen += n
920✔
1458

920✔
1459
                        encodedValues[i+2] = encVal
920✔
1460
                }
1461

1462
                if indexKeyLen > MaxKeyLen {
845✔
1463
                        return fmt.Errorf("%w: can not index entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1464
                }
×
1465

1466
                smkey := MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...)
845✔
1467

845✔
1468
                // no other equivalent entry should be already indexed
845✔
1469
                if index.IsUnique() {
922✔
1470
                        _, valRef, err := tx.getWithPrefix(ctx, smkey, nil)
77✔
1471
                        if err == nil && (valRef.KVMetadata() == nil || !valRef.KVMetadata().Deleted()) {
82✔
1472
                                return store.ErrKeyAlreadyExists
5✔
1473
                        } else if !errors.Is(err, store.ErrKeyNotFound) {
77✔
1474
                                return err
×
1475
                        }
×
1476
                }
1477

1478
                err = tx.setTransient(smkey, nil, encodedRowValue) // only-indexable
840✔
1479
                if err != nil {
840✔
1480
                        return err
×
1481
                }
×
1482
        }
1483

1484
        tx.updatedRows++
2,294✔
1485

2,294✔
1486
        return nil
2,294✔
1487
}
1488

1489
func encodedKey(index *Index, valuesByColID map[uint32]TypedValue) ([]byte, error) {
13,974✔
1490
        valbuf := bytes.Buffer{}
13,974✔
1491

13,974✔
1492
        indexKeyLen := 0
13,974✔
1493

13,974✔
1494
        for _, col := range index.cols {
27,960✔
1495
                rval, specified := valuesByColID[col.id]
13,986✔
1496
                if !specified || rval.IsNull() {
13,989✔
1497
                        return nil, ErrPKCanNotBeNull
3✔
1498
                }
3✔
1499

1500
                encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
13,983✔
1501
                if err != nil {
13,985✔
1502
                        return nil, fmt.Errorf("%w: index of table '%s' and column '%s'", err, index.table.name, col.colName)
2✔
1503
                }
2✔
1504

1505
                if n > MaxKeyLen {
13,981✔
1506
                        return nil, fmt.Errorf("%w: invalid key entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1507
                }
×
1508

1509
                indexKeyLen += n
13,981✔
1510

13,981✔
1511
                _, err = valbuf.Write(encVal)
13,981✔
1512
                if err != nil {
13,981✔
1513
                        return nil, err
×
1514
                }
×
1515
        }
1516

1517
        if indexKeyLen > MaxKeyLen {
13,969✔
1518
                return nil, fmt.Errorf("%w: invalid key entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1519
        }
×
1520

1521
        return valbuf.Bytes(), nil
13,969✔
1522
}
1523

1524
func (tx *SQLTx) fetchPKRow(ctx context.Context, table *Table, valuesByColID map[uint32]TypedValue) (*Row, error) {
57✔
1525
        pkRanges := make(map[uint32]*typedValueRange, len(table.primaryIndex.cols))
57✔
1526

57✔
1527
        for _, pkCol := range table.primaryIndex.cols {
114✔
1528
                pkVal := valuesByColID[pkCol.id]
57✔
1529

57✔
1530
                pkRanges[pkCol.id] = &typedValueRange{
57✔
1531
                        lRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1532
                        hRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1533
                }
57✔
1534
        }
57✔
1535

1536
        scanSpecs := &ScanSpecs{
57✔
1537
                Index:         table.primaryIndex,
57✔
1538
                rangesByColID: pkRanges,
57✔
1539
        }
57✔
1540

57✔
1541
        r, err := newRawRowReader(tx, nil, table, period{}, table.name, scanSpecs)
57✔
1542
        if err != nil {
57✔
1543
                return nil, err
×
1544
        }
×
1545

1546
        defer func() {
114✔
1547
                r.Close()
57✔
1548
        }()
57✔
1549

1550
        return r.Read(ctx)
57✔
1551
}
1552

1553
// deprecateIndexEntries mark previous index entries as deleted
1554
func (tx *SQLTx) deprecateIndexEntries(
1555
        pkEncVals []byte,
1556
        currValuesByColID, newValuesByColID map[uint32]TypedValue,
1557
        table *Table) (reusableIndexEntries map[uint32]struct{}, err error) {
36✔
1558

36✔
1559
        encodedRowValue, err := tx.encodeRowValue(currValuesByColID, table)
36✔
1560
        if err != nil {
36✔
1561
                return nil, err
×
1562
        }
×
1563

1564
        reusableIndexEntries = make(map[uint32]struct{})
36✔
1565

36✔
1566
        for _, index := range table.indexes {
149✔
1567
                if index.IsPrimary() {
149✔
1568
                        continue
36✔
1569
                }
1570

1571
                encodedValues := make([][]byte, 2+len(index.cols)+1)
77✔
1572
                encodedValues[0] = EncodeID(table.id)
77✔
1573
                encodedValues[1] = EncodeID(index.id)
77✔
1574
                encodedValues[len(encodedValues)-1] = pkEncVals
77✔
1575

77✔
1576
                // existent index entry is deleted only if it differs from existent one
77✔
1577
                sameIndexKey := true
77✔
1578

77✔
1579
                for i, col := range index.cols {
159✔
1580
                        currVal, specified := currValuesByColID[col.id]
82✔
1581
                        if !specified {
82✔
1582
                                currVal = &NullValue{t: col.colType}
×
1583
                        }
×
1584

1585
                        newVal, specified := newValuesByColID[col.id]
82✔
1586
                        if !specified {
86✔
1587
                                newVal = &NullValue{t: col.colType}
4✔
1588
                        }
4✔
1589

1590
                        r, err := currVal.Compare(newVal)
82✔
1591
                        if err != nil {
82✔
1592
                                return nil, err
×
1593
                        }
×
1594

1595
                        sameIndexKey = sameIndexKey && r == 0
82✔
1596

82✔
1597
                        encVal, _, _ := EncodeValueAsKey(currVal, col.colType, col.MaxLen())
82✔
1598

82✔
1599
                        encodedValues[i+3] = encVal
82✔
1600
                }
1601

1602
                // mark existent index entry as deleted
1603
                if sameIndexKey {
127✔
1604
                        reusableIndexEntries[index.id] = struct{}{}
50✔
1605
                } else {
77✔
1606
                        md := store.NewKVMetadata()
27✔
1607

27✔
1608
                        md.AsDeleted(true)
27✔
1609

27✔
1610
                        err = tx.set(MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...), md, encodedRowValue)
27✔
1611
                        if err != nil {
27✔
1612
                                return nil, err
×
1613
                        }
×
1614
                }
1615
        }
1616

1617
        return reusableIndexEntries, nil
36✔
1618
}
1619

1620
type UpdateStmt struct {
1621
        tableRef *tableRef
1622
        where    ValueExp
1623
        updates  []*colUpdate
1624
        indexOn  []string
1625
        limit    ValueExp
1626
        offset   ValueExp
1627
}
1628

1629
type colUpdate struct {
1630
        col string
1631
        op  CmpOperator
1632
        val ValueExp
1633
}
1634

1635
func (stmt *UpdateStmt) readOnly() bool {
4✔
1636
        return false
4✔
1637
}
4✔
1638

1639
func (stmt *UpdateStmt) requiredPrivileges() []SQLPrivilege {
4✔
1640
        return []SQLPrivilege{SQLPrivilegeUpdate}
4✔
1641
}
4✔
1642

1643
func (stmt *UpdateStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1644
        selectStmt := &SelectStmt{
1✔
1645
                ds:    stmt.tableRef,
1✔
1646
                where: stmt.where,
1✔
1647
        }
1✔
1648

1✔
1649
        err := selectStmt.inferParameters(ctx, tx, params)
1✔
1650
        if err != nil {
1✔
1651
                return err
×
1652
        }
×
1653

1654
        table, err := stmt.tableRef.referencedTable(tx)
1✔
1655
        if err != nil {
1✔
1656
                return err
×
1657
        }
×
1658

1659
        for _, update := range stmt.updates {
2✔
1660
                col, err := table.GetColumnByName(update.col)
1✔
1661
                if err != nil {
1✔
1662
                        return err
×
1663
                }
×
1664

1665
                err = update.val.requiresType(col.colType, make(map[string]ColDescriptor), params, table.name)
1✔
1666
                if err != nil {
1✔
1667
                        return err
×
1668
                }
×
1669
        }
1670

1671
        return nil
1✔
1672
}
1673

1674
func (stmt *UpdateStmt) validate(table *Table) error {
21✔
1675
        colIDs := make(map[uint32]struct{}, len(stmt.updates))
21✔
1676

21✔
1677
        for _, update := range stmt.updates {
44✔
1678
                if update.op != EQ {
23✔
1679
                        return ErrIllegalArguments
×
1680
                }
×
1681

1682
                col, err := table.GetColumnByName(update.col)
23✔
1683
                if err != nil {
24✔
1684
                        return err
1✔
1685
                }
1✔
1686

1687
                if table.PrimaryIndex().IncludesCol(col.id) {
22✔
1688
                        return ErrPKCanNotBeUpdated
×
1689
                }
×
1690

1691
                _, duplicated := colIDs[col.id]
22✔
1692
                if duplicated {
22✔
1693
                        return ErrDuplicatedColumn
×
1694
                }
×
1695

1696
                colIDs[col.id] = struct{}{}
22✔
1697
        }
1698

1699
        return nil
20✔
1700
}
1701

1702
// execAt applies the UPDATE inside tx.
//
// Rows are located by resolving a SelectStmt built from the statement's table
// reference, WHERE clause, indexOn, LIMIT and OFFSET. The SET list is
// validated once against the resolved table, before any row is read. Then,
// for every matched row:
//   - each SET expression is substituted with params, reduced against the
//     row as it was read, and type-checked against the target column;
//   - the new values are written into the row, and the table's CHECK
//     constraints are evaluated on it;
//   - the primary key is re-encoded, the primary-index entry is looked up in
//     tx, and the row is stored with tx.doUpsert.
//
// Returns tx on success. On the first error it returns (nil, err). This
// function does not undo writes already issued for earlier rows.
func (stmt *UpdateStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
	selectStmt := &SelectStmt{
		ds:      stmt.tableRef,
		where:   stmt.where,
		indexOn: stmt.indexOn,
		limit:   stmt.limit,
		offset:  stmt.offset,
	}

	rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
	if err != nil {
		return nil, err
	}
	defer rowReader.Close()

	// The target table is taken from the resolved scan.
	table := rowReader.ScanSpecs().Index.table

	err = stmt.validate(table)
	if err != nil {
		return nil, err
	}

	// Column descriptors of the scanned rows; used below to type-check the
	// reduced SET values.
	cols, err := rowReader.colsBySelector(ctx)
	if err != nil {
		return nil, err
	}

	for {
		row, err := rowReader.Read(ctx)
		if errors.Is(err, ErrNoMoreRows) {
			break
		} else if err != nil {
			return nil, err
		}

		// Start from the row's current values, keyed by column id.
		valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))

		for _, col := range table.cols {
			encSel := EncodeSelector("", table.name, col.colName)
			valuesByColID[col.id] = row.ValuesBySelector[encSel]
		}

		// Overlay the SET assignments. Every expression is reduced against
		// `row` before row itself is modified (that happens in the next
		// loop), so all assignments observe the pre-update values.
		for _, update := range stmt.updates {
			col, err := table.GetColumnByName(update.col)
			if err != nil {
				return nil, err
			}

			sval, err := update.val.substitute(params)
			if err != nil {
				return nil, err
			}

			rval, err := sval.reduce(tx, row, table.name)
			if err != nil {
				return nil, err
			}

			err = rval.requiresType(col.colType, cols, nil, table.name)
			if err != nil {
				return nil, err
			}

			valuesByColID[col.id] = rval
		}

		// Mirror the new values into row so the CHECK constraints see the
		// updated row. Assumes the row's positional layout follows
		// table.cols order. TODO: confirm.
		for i, col := range table.cols {
			v := valuesByColID[col.id]

			row.ValuesByPosition[i] = v
			row.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
		}

		if err := checkConstraints(tx, table.checkConstraints, row, table.name); err != nil {
			return nil, err
		}

		// validate rejects assignments to primary-key columns, so this is the
		// key of the row that was read.
		pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
		if err != nil {
			return nil, err
		}

		// primary index entry
		mkey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)

		// mkey must exist: any lookup error aborts the update
		_, err = tx.get(ctx, mkey)
		if err != nil {
			return nil, err
		}

		// NOTE(review): the trailing `true` presumably marks this upsert as an
		// update of an existing row. Confirm against doUpsert.
		err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, true)
		if err != nil {
			return nil, err
		}
	}

	return tx, nil
}
1801

1802
// DeleteFromStmt is the parsed form of a DELETE FROM statement. execAt
// deletes every row yielded by a SelectStmt built from these fields.
type DeleteFromStmt struct {
	tableRef *tableRef  // table rows are deleted from
	where    ValueExp   // row filter, passed to the row-selecting SelectStmt
	indexOn  []string   // passed to the row-selecting SelectStmt as indexOn
	orderBy  []*OrdExp  // ORDER BY, passed to the row-selecting SelectStmt
	limit    ValueExp   // LIMIT, passed to the row-selecting SelectStmt
	offset   ValueExp   // OFFSET, passed to the row-selecting SelectStmt
}
1810

1811
func NewDeleteFromStmt(table string, where ValueExp, orderBy []*OrdExp, limit ValueExp) *DeleteFromStmt {
4✔
1812
        return &DeleteFromStmt{
4✔
1813
                tableRef: NewTableRef(table, ""),
4✔
1814
                where:    where,
4✔
1815
                orderBy:  orderBy,
4✔
1816
                limit:    limit,
4✔
1817
        }
4✔
1818
}
4✔
1819

1820
func (stmt *DeleteFromStmt) readOnly() bool {
1✔
1821
        return false
1✔
1822
}
1✔
1823

1824
func (stmt *DeleteFromStmt) requiredPrivileges() []SQLPrivilege {
1✔
1825
        return []SQLPrivilege{SQLPrivilegeDelete}
1✔
1826
}
1✔
1827

1828
func (stmt *DeleteFromStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1829
        selectStmt := &SelectStmt{
1✔
1830
                ds:      stmt.tableRef,
1✔
1831
                where:   stmt.where,
1✔
1832
                orderBy: stmt.orderBy,
1✔
1833
        }
1✔
1834
        return selectStmt.inferParameters(ctx, tx, params)
1✔
1835
}
1✔
1836

1837
func (stmt *DeleteFromStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
1838
        selectStmt := &SelectStmt{
15✔
1839
                ds:      stmt.tableRef,
15✔
1840
                where:   stmt.where,
15✔
1841
                indexOn: stmt.indexOn,
15✔
1842
                orderBy: stmt.orderBy,
15✔
1843
                limit:   stmt.limit,
15✔
1844
                offset:  stmt.offset,
15✔
1845
        }
15✔
1846

15✔
1847
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
15✔
1848
        if err != nil {
17✔
1849
                return nil, err
2✔
1850
        }
2✔
1851
        defer rowReader.Close()
13✔
1852

13✔
1853
        table := rowReader.ScanSpecs().Index.table
13✔
1854

13✔
1855
        for {
147✔
1856
                row, err := rowReader.Read(ctx)
134✔
1857
                if errors.Is(err, ErrNoMoreRows) {
146✔
1858
                        break
12✔
1859
                }
1860
                if err != nil {
123✔
1861
                        return nil, err
1✔
1862
                }
1✔
1863

1864
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
121✔
1865

121✔
1866
                for _, col := range table.cols {
406✔
1867
                        encSel := EncodeSelector("", table.name, col.colName)
285✔
1868
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
285✔
1869
                }
285✔
1870

1871
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
121✔
1872
                if err != nil {
121✔
1873
                        return nil, err
×
1874
                }
×
1875

1876
                err = tx.deleteIndexEntries(pkEncVals, valuesByColID, table)
121✔
1877
                if err != nil {
121✔
1878
                        return nil, err
×
1879
                }
×
1880

1881
                tx.updatedRows++
121✔
1882
        }
1883
        return tx, nil
12✔
1884
}
1885

1886
func (tx *SQLTx) deleteIndexEntries(pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table) error {
121✔
1887
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
121✔
1888
        if err != nil {
121✔
1889
                return err
×
1890
        }
×
1891

1892
        for _, index := range table.indexes {
291✔
1893
                if !index.IsPrimary() {
219✔
1894
                        continue
49✔
1895
                }
1896

1897
                encodedValues := make([][]byte, 3+len(index.cols))
121✔
1898
                encodedValues[0] = EncodeID(DatabaseID)
121✔
1899
                encodedValues[1] = EncodeID(table.id)
121✔
1900
                encodedValues[2] = EncodeID(index.id)
121✔
1901

121✔
1902
                for i, col := range index.cols {
242✔
1903
                        val, specified := valuesByColID[col.id]
121✔
1904
                        if !specified {
121✔
1905
                                val = &NullValue{t: col.colType}
×
1906
                        }
×
1907

1908
                        encVal, _, _ := EncodeValueAsKey(val, col.colType, col.MaxLen())
121✔
1909

121✔
1910
                        encodedValues[i+3] = encVal
121✔
1911
                }
1912

1913
                md := store.NewKVMetadata()
121✔
1914

121✔
1915
                md.AsDeleted(true)
121✔
1916

121✔
1917
                err := tx.set(MapKey(tx.sqlPrefix(), RowPrefix, encodedValues...), md, encodedRowValue)
121✔
1918
                if err != nil {
121✔
1919
                        return err
×
1920
                }
×
1921
        }
1922

1923
        return nil
121✔
1924
}
1925

1926
// ValueExp is the interface shared by SQL value expressions. Every literal
// TypedValue in this file implements it; other implementations presumably
// live elsewhere in the package.
type ValueExp interface {
	// inferType returns the type the expression evaluates to. cols describes
	// the columns in scope, and implicitTable names the table that
	// unqualified columns refer to. params presumably collects types inferred
	// for named parameters. TODO: confirm.
	inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
	// requiresType returns an error if the expression cannot be interpreted
	// as type t. An untyped expression may adopt t (NullValue does).
	requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
	// substitute returns the expression with its parameters bound from
	// params. Literals return themselves.
	substitute(params map[string]interface{}) (ValueExp, error)
	// selectors returns the column selectors the expression references
	// (nil for literals).
	selectors() []Selector
	// reduce evaluates the expression against row and returns the result.
	reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error)
	// reduceSelectors returns the expression with selectors resolved from
	// row where possible. Literals return themselves.
	reduceSelectors(row *Row, implicitTable string) ValueExp
	// isConstant reports whether the expression is a constant (true for
	// every literal in this file).
	isConstant() bool
	// selectorRanges records in rangesByColID the value ranges the
	// expression imposes on table's columns (no-op for literals).
	selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error
	// String renders the expression as text (e.g. "NULL", quoted strings).
	String() string
}
1937

1938
type typedValueRange struct {
1939
        lRange *typedValueSemiRange
1940
        hRange *typedValueSemiRange
1941
}
1942

1943
type typedValueSemiRange struct {
1944
        val       TypedValue
1945
        inclusive bool
1946
}
1947

1948
func (r *typedValueRange) unitary() bool {
19✔
1949
        // TODO: this simplified implementation doesn't cover all unitary cases e.g. 3<=v<4
19✔
1950
        if r.lRange == nil || r.hRange == nil {
19✔
1951
                return false
×
1952
        }
×
1953

1954
        res, _ := r.lRange.val.Compare(r.hRange.val)
19✔
1955
        return res == 0 && r.lRange.inclusive && r.hRange.inclusive
19✔
1956
}
1957

1958
func (r *typedValueRange) refineWith(refiningRange *typedValueRange) error {
5✔
1959
        if r.lRange == nil {
6✔
1960
                r.lRange = refiningRange.lRange
1✔
1961
        } else if r.lRange != nil && refiningRange.lRange != nil {
6✔
1962
                maxRange, err := maxSemiRange(r.lRange, refiningRange.lRange)
1✔
1963
                if err != nil {
1✔
1964
                        return err
×
1965
                }
×
1966
                r.lRange = maxRange
1✔
1967
        }
1968

1969
        if r.hRange == nil {
8✔
1970
                r.hRange = refiningRange.hRange
3✔
1971
        } else if r.hRange != nil && refiningRange.hRange != nil {
7✔
1972
                minRange, err := minSemiRange(r.hRange, refiningRange.hRange)
2✔
1973
                if err != nil {
2✔
1974
                        return err
×
1975
                }
×
1976
                r.hRange = minRange
2✔
1977
        }
1978

1979
        return nil
5✔
1980
}
1981

1982
func (r *typedValueRange) extendWith(extendingRange *typedValueRange) error {
5✔
1983
        if r.lRange == nil || extendingRange.lRange == nil {
7✔
1984
                r.lRange = nil
2✔
1985
        } else {
5✔
1986
                minRange, err := minSemiRange(r.lRange, extendingRange.lRange)
3✔
1987
                if err != nil {
3✔
1988
                        return err
×
1989
                }
×
1990
                r.lRange = minRange
3✔
1991
        }
1992

1993
        if r.hRange == nil || extendingRange.hRange == nil {
8✔
1994
                r.hRange = nil
3✔
1995
        } else {
5✔
1996
                maxRange, err := maxSemiRange(r.hRange, extendingRange.hRange)
2✔
1997
                if err != nil {
2✔
1998
                        return err
×
1999
                }
×
2000
                r.hRange = maxRange
2✔
2001
        }
2002

2003
        return nil
5✔
2004
}
2005

2006
func maxSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
3✔
2007
        r, err := or1.val.Compare(or2.val)
3✔
2008
        if err != nil {
3✔
2009
                return nil, err
×
2010
        }
×
2011

2012
        maxVal := or1.val
3✔
2013
        if r < 0 {
5✔
2014
                maxVal = or2.val
2✔
2015
        }
2✔
2016

2017
        return &typedValueSemiRange{
3✔
2018
                val:       maxVal,
3✔
2019
                inclusive: or1.inclusive && or2.inclusive,
3✔
2020
        }, nil
3✔
2021
}
2022

2023
func minSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
5✔
2024
        r, err := or1.val.Compare(or2.val)
5✔
2025
        if err != nil {
5✔
2026
                return nil, err
×
2027
        }
×
2028

2029
        minVal := or1.val
5✔
2030
        if r > 0 {
9✔
2031
                minVal = or2.val
4✔
2032
        }
4✔
2033

2034
        return &typedValueSemiRange{
5✔
2035
                val:       minVal,
5✔
2036
                inclusive: or1.inclusive || or2.inclusive,
5✔
2037
        }, nil
5✔
2038
}
2039

2040
// TypedValue is a fully evaluated SQL value: a ValueExp that also exposes its
// SQL type and its underlying Go value.
type TypedValue interface {
	ValueExp
	// Type returns the value's SQL type.
	Type() SQLValueType
	// RawValue returns the underlying Go value. For the literals in this file
	// it is nil (NULL), int64, time.Time, string, uuid.UUID, bool or []byte.
	RawValue() interface{}
	// Compare returns a negative number, zero or a positive number when the
	// receiver sorts before, equal to or after val. It returns an error
	// (ErrNotComparableValues for the literals here) when the types cannot be
	// compared.
	Compare(val TypedValue) (int, error)
	// IsNull reports whether the value is SQL NULL.
	IsNull() bool
}
2047

2048
type Tuple []TypedValue
2049

2050
func (t Tuple) Compare(other Tuple) (int, int, error) {
204,158✔
2051
        if len(t) != len(other) {
204,158✔
2052
                return -1, -1, ErrNotComparableValues
×
2053
        }
×
2054

2055
        for i := range t {
431,257✔
2056
                res, err := t[i].Compare(other[i])
227,099✔
2057
                if err != nil || res != 0 {
420,947✔
2058
                        return res, i, err
193,848✔
2059
                }
193,848✔
2060
        }
2061
        return 0, -1, nil
10,310✔
2062
}
2063

2064
func NewNull(t SQLValueType) *NullValue {
400✔
2065
        return &NullValue{t: t}
400✔
2066
}
400✔
2067

2068
type NullValue struct {
2069
        t SQLValueType
2070
}
2071

2072
func (n *NullValue) Type() SQLValueType {
107✔
2073
        return n.t
107✔
2074
}
107✔
2075

2076
func (n *NullValue) RawValue() interface{} {
363✔
2077
        return nil
363✔
2078
}
363✔
2079

2080
func (n *NullValue) IsNull() bool {
386✔
2081
        return true
386✔
2082
}
386✔
2083

2084
func (n *NullValue) String() string {
4✔
2085
        return "NULL"
4✔
2086
}
4✔
2087

2088
func (n *NullValue) Compare(val TypedValue) (int, error) {
84✔
2089
        if n.t != AnyType && val.Type() != AnyType && n.t != val.Type() {
86✔
2090
                return 0, ErrNotComparableValues
2✔
2091
        }
2✔
2092

2093
        if val.RawValue() == nil {
123✔
2094
                return 0, nil
41✔
2095
        }
41✔
2096
        return -1, nil
41✔
2097
}
2098

2099
func (v *NullValue) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
10✔
2100
        return v.t, nil
10✔
2101
}
10✔
2102

2103
func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
12✔
2104
        if v.t == t {
19✔
2105
                return nil
7✔
2106
        }
7✔
2107

2108
        if v.t != AnyType {
6✔
2109
                return ErrInvalidTypes
1✔
2110
        }
1✔
2111

2112
        v.t = t
4✔
2113

4✔
2114
        return nil
4✔
2115
}
2116

2117
func (v *NullValue) selectors() []Selector {
17✔
2118
        return nil
17✔
2119
}
17✔
2120

2121
func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) {
402✔
2122
        return v, nil
402✔
2123
}
402✔
2124

2125
func (v *NullValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
355✔
2126
        return v, nil
355✔
2127
}
355✔
2128

2129
func (v *NullValue) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
2130
        return v
10✔
2131
}
10✔
2132

2133
func (v *NullValue) isConstant() bool {
12✔
2134
        return true
12✔
2135
}
12✔
2136

2137
func (v *NullValue) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2138
        return nil
1✔
2139
}
1✔
2140

2141
type Integer struct {
2142
        val int64
2143
}
2144

2145
func NewInteger(val int64) *Integer {
315✔
2146
        return &Integer{val: val}
315✔
2147
}
315✔
2148

2149
func (v *Integer) Type() SQLValueType {
302,817✔
2150
        return IntegerType
302,817✔
2151
}
302,817✔
2152

2153
func (v *Integer) IsNull() bool {
116,552✔
2154
        return false
116,552✔
2155
}
116,552✔
2156

2157
func (v *Integer) String() string {
54✔
2158
        return strconv.FormatInt(v.val, 10)
54✔
2159
}
54✔
2160

2161
func (v *Integer) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
99✔
2162
        return IntegerType, nil
99✔
2163
}
99✔
2164

2165
func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
64✔
2166
        if t != IntegerType && t != JSONType {
68✔
2167
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
4✔
2168
        }
4✔
2169

60✔
2170
        return nil
2171
}
2172

54✔
2173
func (v *Integer) selectors() []Selector {
54✔
2174
        return nil
54✔
2175
}
2176

15,163✔
2177
func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) {
15,163✔
2178
        return v, nil
15,163✔
2179
}
2180

16,913✔
2181
func (v *Integer) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
16,913✔
2182
        return v, nil
16,913✔
2183
}
2184

9✔
2185
func (v *Integer) reduceSelectors(row *Row, implicitTable string) ValueExp {
9✔
2186
        return v
9✔
2187
}
2188

120✔
2189
func (v *Integer) isConstant() bool {
120✔
2190
        return true
120✔
2191
}
2192

1✔
2193
func (v *Integer) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2194
        return nil
1✔
2195
}
2196

177,192✔
2197
func (v *Integer) RawValue() interface{} {
177,192✔
2198
        return v.val
177,192✔
2199
}
2200

93,041✔
2201
func (v *Integer) Compare(val TypedValue) (int, error) {
93,084✔
2202
        if val.IsNull() {
43✔
2203
                return 1, nil
43✔
2204
        }
2205

92,999✔
2206
        if val.Type() == JSONType {
1✔
2207
                res, err := val.Compare(v)
1✔
2208
                return -res, err
1✔
2209
        }
2210

92,997✔
UNCOV
2211
        if val.Type() == Float64Type {
×
2212
                r, err := val.Compare(v)
×
2213
                return r * -1, err
×
2214
        }
2215

93,004✔
2216
        if val.Type() != IntegerType {
7✔
2217
                return 0, ErrNotComparableValues
7✔
2218
        }
2219

92,990✔
2220
        rval := val.RawValue().(int64)
92,990✔
2221

111,638✔
2222
        if v.val == rval {
18,648✔
2223
                return 0, nil
18,648✔
2224
        }
2225

109,218✔
2226
        if v.val > rval {
34,876✔
2227
                return 1, nil
34,876✔
2228
        }
2229

39,466✔
2230
        return -1, nil
2231
}
2232

2233
type Timestamp struct {
2234
        val time.Time
2235
}
2236

40,315✔
2237
func (v *Timestamp) Type() SQLValueType {
40,315✔
2238
        return TimestampType
40,315✔
2239
}
2240

32,894✔
2241
func (v *Timestamp) IsNull() bool {
32,894✔
2242
        return false
32,894✔
2243
}
2244

1✔
2245
func (v *Timestamp) String() string {
1✔
2246
        return v.val.Format("2006-01-02 15:04:05.999999")
1✔
2247
}
2248

1✔
2249
func (v *Timestamp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2250
        return TimestampType, nil
1✔
2251
}
2252

14✔
2253
func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
2254
        if t != TimestampType {
1✔
2255
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2256
        }
2257

13✔
2258
        return nil
2259
}
2260

1✔
2261
func (v *Timestamp) selectors() []Selector {
1✔
2262
        return nil
1✔
2263
}
2264

1,165✔
2265
func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) {
1,165✔
2266
        return v, nil
1,165✔
2267
}
2268

2,025✔
2269
func (v *Timestamp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2,025✔
2270
        return v, nil
2,025✔
2271
}
2272

1✔
2273
func (v *Timestamp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2274
        return v
1✔
2275
}
2276

1✔
2277
func (v *Timestamp) isConstant() bool {
1✔
2278
        return true
1✔
2279
}
2280

1✔
2281
func (v *Timestamp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2282
        return nil
1✔
2283
}
2284

57,410✔
2285
func (v *Timestamp) RawValue() interface{} {
57,410✔
2286
        return v.val
57,410✔
2287
}
2288

29,744✔
2289
func (v *Timestamp) Compare(val TypedValue) (int, error) {
29,746✔
2290
        if val.IsNull() {
2✔
2291
                return 1, nil
2✔
2292
        }
2293

29,743✔
2294
        if val.Type() != TimestampType {
1✔
2295
                return 0, ErrNotComparableValues
1✔
2296
        }
2297

29,741✔
2298
        rval := val.RawValue().(time.Time)
29,741✔
2299

44,251✔
2300
        if v.val.Before(rval) {
14,510✔
2301
                return -1, nil
14,510✔
2302
        }
2303

30,271✔
2304
        if v.val.After(rval) {
15,040✔
2305
                return 1, nil
15,040✔
2306
        }
2307

191✔
2308
        return 0, nil
2309
}
2310

2311
type Varchar struct {
2312
        val string
2313
}
2314

2,137✔
2315
func NewVarchar(val string) *Varchar {
2,137✔
2316
        return &Varchar{val: val}
2,137✔
2317
}
2318

127,949✔
2319
func (v *Varchar) Type() SQLValueType {
127,949✔
2320
        return VarcharType
127,949✔
2321
}
2322

64,576✔
2323
func (v *Varchar) IsNull() bool {
64,576✔
2324
        return false
64,576✔
2325
}
2326

18✔
2327
func (v *Varchar) String() string {
18✔
2328
        return fmt.Sprintf("'%s'", v.val)
18✔
2329
}
2330

62✔
2331
func (v *Varchar) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
62✔
2332
        return VarcharType, nil
62✔
2333
}
2334

143✔
2335
func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
147✔
2336
        if t != VarcharType && t != JSONType {
4✔
2337
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
4✔
2338
        }
139✔
2339
        return nil
2340
}
2341

43✔
2342
func (v *Varchar) selectors() []Selector {
43✔
2343
        return nil
43✔
2344
}
2345

4,955✔
2346
func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) {
4,955✔
2347
        return v, nil
4,955✔
2348
}
2349

5,640✔
2350
func (v *Varchar) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5,640✔
2351
        return v, nil
5,640✔
2352
}
UNCOV
2353

×
2354
func (v *Varchar) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2355
        return v
×
2356
}
2357

39✔
2358
func (v *Varchar) isConstant() bool {
39✔
2359
        return true
39✔
2360
}
2361

1✔
2362
func (v *Varchar) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2363
        return nil
1✔
2364
}
2365

90,158✔
2366
func (v *Varchar) RawValue() interface{} {
90,158✔
2367
        return v.val
90,158✔
2368
}
2369

58,120✔
2370
func (v *Varchar) Compare(val TypedValue) (int, error) {
58,174✔
2371
        if val.IsNull() {
54✔
2372
                return 1, nil
54✔
2373
        }
2374

59,067✔
2375
        if val.Type() == JSONType {
1,001✔
2376
                res, err := val.Compare(v)
1,001✔
2377
                return -res, err
1,001✔
2378
        }
2379

57,066✔
2380
        if val.Type() != VarcharType {
1✔
2381
                return 0, ErrNotComparableValues
1✔
2382
        }
2383

57,064✔
2384
        rval := val.RawValue().(string)
57,064✔
2385

57,064✔
2386
        return bytes.Compare([]byte(v.val), []byte(rval)), nil
2387
}
2388

2389
type UUID struct {
2390
        val uuid.UUID
2391
}
2392

1✔
2393
func NewUUID(val uuid.UUID) *UUID {
1✔
2394
        return &UUID{val: val}
1✔
2395
}
2396

10✔
2397
func (v *UUID) Type() SQLValueType {
10✔
2398
        return UUIDType
10✔
2399
}
2400

26✔
2401
func (v *UUID) IsNull() bool {
26✔
2402
        return false
26✔
2403
}
2404

1✔
2405
func (v *UUID) String() string {
1✔
2406
        return v.val.String()
1✔
2407
}
2408

1✔
2409
func (v *UUID) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2410
        return UUIDType, nil
1✔
2411
}
2412

4✔
2413
func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
2414
        if t != UUIDType {
2✔
2415
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
2✔
2416
        }
2417

2✔
2418
        return nil
2419
}
2420

1✔
2421
func (v *UUID) selectors() []Selector {
1✔
2422
        return nil
1✔
2423
}
2424

6✔
2425
func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) {
6✔
2426
        return v, nil
6✔
2427
}
2428

5✔
2429
func (v *UUID) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5✔
2430
        return v, nil
5✔
2431
}
2432

1✔
2433
func (v *UUID) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2434
        return v
1✔
2435
}
2436

1✔
2437
func (v *UUID) isConstant() bool {
1✔
2438
        return true
1✔
2439
}
2440

1✔
2441
func (v *UUID) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2442
        return nil
1✔
2443
}
2444

41✔
2445
func (v *UUID) RawValue() interface{} {
41✔
2446
        return v.val
41✔
2447
}
2448

5✔
2449
func (v *UUID) Compare(val TypedValue) (int, error) {
7✔
2450
        if val.IsNull() {
2✔
2451
                return 1, nil
2✔
2452
        }
2453

4✔
2454
        if val.Type() != UUIDType {
1✔
2455
                return 0, ErrNotComparableValues
1✔
2456
        }
2457

2✔
2458
        rval := val.RawValue().(uuid.UUID)
2✔
2459

2✔
2460
        return bytes.Compare(v.val[:], rval[:]), nil
2461
}
2462

2463
type Bool struct {
2464
        val bool
2465
}
2466

208✔
2467
func NewBool(val bool) *Bool {
208✔
2468
        return &Bool{val: val}
208✔
2469
}
2470

1,966✔
2471
func (v *Bool) Type() SQLValueType {
1,966✔
2472
        return BooleanType
1,966✔
2473
}
2474

1,479✔
2475
func (v *Bool) IsNull() bool {
1,479✔
2476
        return false
1,479✔
2477
}
2478

41✔
2479
func (v *Bool) String() string {
41✔
2480
        return strconv.FormatBool(v.val)
41✔
2481
}
2482

32✔
2483
func (v *Bool) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
32✔
2484
        return BooleanType, nil
32✔
2485
}
2486

57✔
2487
func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
62✔
2488
        if t != BooleanType && t != JSONType {
5✔
2489
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
5✔
2490
        }
52✔
2491
        return nil
2492
}
2493

5✔
2494
func (v *Bool) selectors() []Selector {
5✔
2495
        return nil
5✔
2496
}
2497

639✔
2498
func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) {
639✔
2499
        return v, nil
639✔
2500
}
2501

726✔
2502
func (v *Bool) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2503
        return v, nil
726✔
2504
}
UNCOV
2505

×
2506
func (v *Bool) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2507
        return v
×
2508
}
2509

3✔
2510
func (v *Bool) isConstant() bool {
3✔
2511
        return true
3✔
2512
}
2513

8✔
2514
func (v *Bool) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
2515
        return nil
8✔
2516
}
2517

1,712✔
2518
func (v *Bool) RawValue() interface{} {
1,712✔
2519
        return v.val
1,712✔
2520
}
2521

575✔
2522
func (v *Bool) Compare(val TypedValue) (int, error) {
605✔
2523
        if val.IsNull() {
30✔
2524
                return 1, nil
30✔
2525
        }
2526

546✔
2527
        if val.Type() == JSONType {
1✔
2528
                res, err := val.Compare(v)
1✔
2529
                return -res, err
1✔
2530
        }
2531

544✔
UNCOV
2532
        if val.Type() != BooleanType {
×
2533
                return 0, ErrNotComparableValues
×
2534
        }
2535

544✔
2536
        rval := val.RawValue().(bool)
544✔
2537

888✔
2538
        if v.val == rval {
344✔
2539
                return 0, nil
344✔
2540
        }
2541

206✔
2542
        if v.val {
6✔
2543
                return 1, nil
6✔
2544
        }
2545

194✔
2546
        return -1, nil
2547
}
2548

2549
type Blob struct {
2550
        val []byte
2551
}
2552

286✔
2553
func NewBlob(val []byte) *Blob {
286✔
2554
        return &Blob{val: val}
286✔
2555
}
2556

53✔
2557
func (v *Blob) Type() SQLValueType {
53✔
2558
        return BLOBType
53✔
2559
}
2560

2,312✔
2561
func (v *Blob) IsNull() bool {
2,312✔
2562
        return false
2,312✔
2563
}
2564

2✔
2565
func (v *Blob) String() string {
2✔
2566
        return hex.EncodeToString(v.val)
2✔
2567
}
2568

1✔
2569
func (v *Blob) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2570
        return BLOBType, nil
1✔
2571
}
2572

2✔
2573
func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
3✔
2574
        if t != BLOBType {
1✔
2575
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BLOBType, t)
1✔
2576
        }
2577

1✔
2578
        return nil
2579
}
2580

1✔
2581
func (v *Blob) selectors() []Selector {
1✔
2582
        return nil
1✔
2583
}
2584

714✔
2585
func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) {
714✔
2586
        return v, nil
714✔
2587
}
2588

726✔
2589
func (v *Blob) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2590
        return v, nil
726✔
2591
}
UNCOV
2592

×
2593
func (v *Blob) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2594
        return v
×
2595
}
2596

7✔
2597
func (v *Blob) isConstant() bool {
7✔
2598
        return true
7✔
2599
}
UNCOV
2600

×
2601
func (v *Blob) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2602
        return nil
×
2603
}
2604

2,570✔
2605
func (v *Blob) RawValue() interface{} {
2,570✔
2606
        return v.val
2,570✔
2607
}
2608

25✔
2609
func (v *Blob) Compare(val TypedValue) (int, error) {
25✔
UNCOV
2610
        if val.IsNull() {
×
2611
                return 1, nil
×
2612
        }
2613

25✔
UNCOV
2614
        if val.Type() != BLOBType {
×
2615
                return 0, ErrNotComparableValues
×
2616
        }
2617

25✔
2618
        rval := val.RawValue().([]byte)
25✔
2619

25✔
2620
        return bytes.Compare(v.val, rval), nil
2621
}
2622

2623
type Float64 struct {
2624
        val float64
2625
}
2626

1,249✔
2627
func NewFloat64(val float64) *Float64 {
1,249✔
2628
        return &Float64{val: val}
1,249✔
2629
}
2630

207,440✔
2631
func (v *Float64) Type() SQLValueType {
207,440✔
2632
        return Float64Type
207,440✔
2633
}
2634

5,599✔
2635
func (v *Float64) IsNull() bool {
5,599✔
2636
        return false
5,599✔
2637
}
2638

3✔
2639
func (v *Float64) String() string {
3✔
2640
        return strconv.FormatFloat(float64(v.val), 'f', -1, 64)
3✔
2641
}
2642

15✔
2643
func (v *Float64) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
15✔
2644
        return Float64Type, nil
15✔
2645
}
2646

20✔
2647
func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
21✔
2648
        if t != Float64Type && t != JSONType {
1✔
2649
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, Float64Type, t)
1✔
2650
        }
19✔
2651
        return nil
2652
}
2653

3✔
2654
func (v *Float64) selectors() []Selector {
3✔
2655
        return nil
3✔
2656
}
2657

1,818✔
2658
func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) {
1,818✔
2659
        return v, nil
1,818✔
2660
}
2661

3,429✔
2662
func (v *Float64) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
3,429✔
2663
        return v, nil
3,429✔
2664
}
2665

1✔
2666
func (v *Float64) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2667
        return v
1✔
2668
}
2669

5✔
2670
func (v *Float64) isConstant() bool {
5✔
2671
        return true
5✔
2672
}
2673

1✔
2674
func (v *Float64) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2675
        return nil
1✔
2676
}
2677

365,970✔
2678
func (v *Float64) RawValue() interface{} {
365,970✔
2679
        return v.val
365,970✔
2680
}
2681

61,877✔
2682
func (v *Float64) Compare(val TypedValue) (int, error) {
61,878✔
2683
        if val.Type() == JSONType {
1✔
2684
                res, err := val.Compare(v)
1✔
2685
                return -res, err
1✔
2686
        }
2687

61,876✔
2688
        convVal, err := mayApplyImplicitConversion(val.RawValue(), Float64Type)
61,877✔
2689
        if err != nil {
1✔
2690
                return 0, err
1✔
2691
        }
2692

61,878✔
2693
        if convVal == nil {
3✔
2694
                return 1, nil
3✔
2695
        }
2696

61,872✔
2697
        rval, ok := convVal.(float64)
61,872✔
UNCOV
2698
        if !ok {
×
2699
                return 0, ErrNotComparableValues
×
2700
        }
2701

61,999✔
2702
        if v.val == rval {
127✔
2703
                return 0, nil
127✔
2704
        }
2705

90,618✔
2706
        if v.val > rval {
28,873✔
2707
                return 1, nil
28,873✔
2708
        }
2709

32,872✔
2710
        return -1, nil
2711
}
2712

2713
type FnCall struct {
2714
        fn     string
2715
        params []ValueExp
2716
}
2717

24✔
2718
func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
24✔
2719
        fn, err := v.resolveFunc()
25✔
2720
        if err != nil {
1✔
2721
                return AnyType, nil
1✔
2722
        }
23✔
2723
        return fn.InferType(cols, params, implicitTable)
2724
}
2725

20✔
2726
func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2727
        fn, err := v.resolveFunc()
21✔
2728
        if err != nil {
1✔
2729
                return err
1✔
2730
        }
19✔
2731
        return fn.RequiresType(t, cols, params, implicitTable)
2732
}
2733

38✔
2734
func (v *FnCall) selectors() []Selector {
38✔
2735
        selectors := make([]Selector, 0)
106✔
2736
        for _, param := range v.params {
68✔
2737
                selectors = append(selectors, param.selectors()...)
68✔
2738
        }
38✔
2739
        return selectors
2740
}
2741

438✔
2742
func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) {
438✔
2743
        ps := make([]ValueExp, len(v.params))
808✔
2744
        for i, p := range v.params {
370✔
2745
                ps[i], err = p.substitute(params)
370✔
UNCOV
2746
                if err != nil {
×
2747
                        return nil, err
×
2748
                }
2749
        }
2750

438✔
2751
        return &FnCall{
438✔
2752
                fn:     v.fn,
438✔
2753
                params: ps,
438✔
2754
        }, nil
2755
}
2756

438✔
2757
func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
438✔
2758
        fn, err := v.resolveFunc()
439✔
2759
        if err != nil {
1✔
2760
                return nil, err
1✔
2761
        }
2762

437✔
2763
        fnInputs, err := v.reduceParams(tx, row, implicitTable)
437✔
UNCOV
2764
        if err != nil {
×
2765
                return nil, err
×
2766
        }
437✔
2767
        return fn.Apply(tx, fnInputs)
2768
}
2769

437✔
2770
func (v *FnCall) reduceParams(tx *SQLTx, row *Row, implicitTable string) ([]TypedValue, error) {
437✔
2771
        var values []TypedValue
773✔
2772
        if len(v.params) > 0 {
336✔
2773
                values = make([]TypedValue, len(v.params))
706✔
2774
                for i, p := range v.params {
370✔
2775
                        v, err := p.reduce(tx, row, implicitTable)
370✔
UNCOV
2776
                        if err != nil {
×
2777
                                return nil, err
×
2778
                        }
370✔
2779
                        values[i] = v
2780
                }
2781
        }
437✔
2782
        return values, nil
2783
}
2784

482✔
2785
func (v *FnCall) resolveFunc() (Function, error) {
482✔
2786
        fn, exists := builtinFunctions[strings.ToUpper(v.fn)]
485✔
2787
        if !exists {
3✔
2788
                return nil, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn)
3✔
2789
        }
479✔
2790
        return fn, nil
2791
}
UNCOV
2792

×
2793
func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2794
        return v
×
2795
}
2796

13✔
2797
func (v *FnCall) isConstant() bool {
13✔
2798
        return false
13✔
2799
}
UNCOV
2800

×
2801
func (v *FnCall) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2802
        return nil
×
2803
}
2804

1✔
2805
func (v *FnCall) String() string {
1✔
2806
        params := make([]string, len(v.params))
4✔
2807
        for i, p := range v.params {
3✔
2808
                params[i] = p.String()
3✔
2809
        }
1✔
2810
        return v.fn + "(" + strings.Join(params, ",") + ")"
2811
}
2812

2813
type Cast struct {
2814
        val ValueExp
2815
        t   SQLValueType
2816
}
2817

23✔
2818
func (c *Cast) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2819
        _, err := c.val.inferType(cols, params, implicitTable)
24✔
2820
        if err != nil {
1✔
2821
                return AnyType, err
1✔
2822
        }
2823

2824
        // val type may be restricted by compatible conversions, but multiple types may be compatible...
2825

22✔
2826
        return c.t, nil
2827
}
UNCOV
2828

×
2829
func (c *Cast) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
2830
        if c.t != t {
×
2831
                return fmt.Errorf("%w: can not use value cast to %s as %s", ErrInvalidTypes, c.t, t)
×
2832
        }
UNCOV
2833

×
2834
        return nil
2835
}
2836

281✔
2837
func (c *Cast) substitute(params map[string]interface{}) (ValueExp, error) {
281✔
2838
        val, err := c.val.substitute(params)
281✔
UNCOV
2839
        if err != nil {
×
2840
                return nil, err
×
2841
        }
281✔
2842
        c.val = val
281✔
2843
        return c, nil
2844
}
2845

269✔
2846
func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
269✔
2847
        val, err := c.val.reduce(tx, row, implicitTable)
269✔
UNCOV
2848
        if err != nil {
×
2849
                return nil, err
×
2850
        }
2851

269✔
2852
        conv, err := getConverter(val.Type(), c.t)
272✔
2853
        if conv == nil {
3✔
2854
                return nil, err
3✔
2855
        }
2856

266✔
2857
        return conv(val)
2858
}
2859

6✔
2860
func (v *Cast) selectors() []Selector {
6✔
2861
        return v.val.selectors()
6✔
2862
}
UNCOV
2863

×
2864
func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2865
        return &Cast{
×
2866
                val: c.val.reduceSelectors(row, implicitTable),
×
2867
                t:   c.t,
×
2868
        }
×
2869
}
2870

7✔
2871
func (c *Cast) isConstant() bool {
7✔
2872
        return c.val.isConstant()
7✔
2873
}
UNCOV
2874

×
2875
func (c *Cast) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2876
        return nil
×
2877
}
2878

1✔
2879
func (c *Cast) String() string {
1✔
2880
        return fmt.Sprintf("CAST (%s AS %s)", c.val.String(), c.t)
1✔
2881
}
2882

2883
type Param struct {
2884
        id  string
2885
        pos int
2886
}
2887

58✔
2888
func (v *Param) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
58✔
2889
        t, ok := params[v.id]
114✔
2890
        if !ok {
56✔
2891
                params[v.id] = AnyType
56✔
2892
                return AnyType, nil
56✔
2893
        }
2894

2✔
2895
        return t, nil
2896
}
2897

76✔
2898
func (v *Param) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
76✔
2899
        currT, ok := params[v.id]
80✔
2900
        if ok && currT != t && currT != AnyType {
4✔
2901
                return ErrInferredMultipleTypes
4✔
2902
        }
2903

72✔
2904
        params[v.id] = t
72✔
2905

72✔
2906
        return nil
2907
}
2908

6,399✔
2909
func (p *Param) substitute(params map[string]interface{}) (ValueExp, error) {
6,399✔
2910
        val, ok := params[p.id]
6,461✔
2911
        if !ok {
62✔
2912
                return nil, fmt.Errorf("%w(%s)", ErrMissingParameter, p.id)
62✔
2913
        }
2914

6,386✔
2915
        if val == nil {
49✔
2916
                return &NullValue{t: AnyType}, nil
49✔
2917
        }
2918

6,288✔
2919
        switch v := val.(type) {
96✔
2920
        case bool:
192✔
2921
                {
96✔
2922
                        return &Bool{val: v}, nil
96✔
2923
                }
1,752✔
2924
        case string:
3,504✔
2925
                {
1,752✔
2926
                        return &Varchar{val: v}, nil
1,752✔
2927
                }
1,678✔
2928
        case int:
3,356✔
2929
                {
1,678✔
2930
                        return &Integer{val: int64(v)}, nil
1,678✔
UNCOV
2931
                }
×
2932
        case uint:
×
2933
                {
×
2934
                        return &Integer{val: int64(v)}, nil
×
2935
                }
34✔
2936
        case uint64:
68✔
2937
                {
34✔
2938
                        return &Integer{val: int64(v)}, nil
34✔
2939
                }
227✔
2940
        case int64:
454✔
2941
                {
227✔
2942
                        return &Integer{val: v}, nil
227✔
2943
                }
14✔
2944
        case []byte:
28✔
2945
                {
14✔
2946
                        return &Blob{val: v}, nil
14✔
2947
                }
861✔
2948
        case time.Time:
1,722✔
2949
                {
861✔
2950
                        return &Timestamp{val: v.Truncate(time.Microsecond).UTC()}, nil
861✔
2951
                }
1,625✔
2952
        case float64:
3,250✔
2953
                {
1,625✔
2954
                        return &Float64{val: v}, nil
1,625✔
2955
                }
2956
        }
1✔
2957
        return nil, ErrUnsupportedParameter
2958
}
UNCOV
2959

×
2960
func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
×
2961
        return nil, ErrUnexpected
×
2962
}
2963

4✔
2964
func (p *Param) selectors() []Selector {
4✔
2965
        return nil
4✔
2966
}
UNCOV
2967

×
2968
func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2969
        return p
×
2970
}
2971

130✔
2972
func (p *Param) isConstant() bool {
130✔
2973
        return true
130✔
2974
}
2975

6✔
2976
func (v *Param) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
6✔
2977
        return nil
6✔
2978
}
2979

2✔
2980
func (v *Param) String() string {
2✔
2981
        return "@" + v.id
2✔
2982
}
2983

2984
type whenThenClause struct {
2985
        when, then ValueExp
2986
}
2987

2988
type CaseWhenExp struct {
2989
        exp      ValueExp
2990
        whenThen []whenThenClause
2991
        elseExp  ValueExp
2992
}
2993

8✔
2994
func (ce *CaseWhenExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
18✔
2995
        checkType := func(e ValueExp, expectedType SQLValueType) (string, error) {
10✔
2996
                t, err := e.inferType(cols, params, implicitTable)
10✔
UNCOV
2997
                if err != nil {
×
2998
                        return "", err
×
2999
                }
3000

15✔
3001
                if expectedType == AnyType {
5✔
3002
                        return t, nil
5✔
3003
                }
3004

6✔
3005
                if t != expectedType {
1✔
3006
                        if (t == Float64Type && expectedType == IntegerType) ||
1✔
UNCOV
3007
                                (t == IntegerType && expectedType == Float64Type) {
×
3008
                                return Float64Type, nil
×
3009
                        }
1✔
3010
                        return "", fmt.Errorf("%w: CASE types %s and %s cannot be matched", ErrInferredMultipleTypes, expectedType, t)
3011
                }
4✔
3012
                return t, nil
3013
        }
3014

8✔
3015
        searchType := BooleanType
8✔
3016
        inferredResType := AnyType
11✔
3017
        if ce.exp != nil {
3✔
3018
                t, err := ce.exp.inferType(cols, params, implicitTable)
3✔
UNCOV
3019
                if err != nil {
×
3020
                        return "", err
×
3021
                }
3✔
3022
                searchType = t
3023
        }
3024

16✔
3025
        for _, e := range ce.whenThen {
8✔
3026
                whenType, err := e.when.inferType(cols, params, implicitTable)
8✔
UNCOV
3027
                if err != nil {
×
3028
                        return "", err
×
3029
                }
3030

11✔
3031
                if whenType != searchType {
3✔
3032
                        return "", fmt.Errorf("%w: argument of CASE/WHEN must be of type %s, not type %s", ErrInvalidTypes, searchType, whenType)
3✔
3033
                }
3034

5✔
3035
                t, err := checkType(e.then, inferredResType)
5✔
UNCOV
3036
                if err != nil {
×
3037
                        return "", err
×
3038
                }
5✔
3039
                inferredResType = t
3040
        }
3041

10✔
3042
        if ce.elseExp != nil {
5✔
3043
                return checkType(ce.elseExp, inferredResType)
5✔
UNCOV
3044
        }
×
3045
        return inferredResType, nil
3046
}
3047

2✔
3048
func (ce *CaseWhenExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
3049
        inferredType, err := ce.inferType(cols, params, implicitTable)
3✔
3050
        if err != nil {
1✔
3051
                return err
1✔
3052
        }
3053

1✔
UNCOV
3054
        if inferredType != t {
×
3055
                return fmt.Errorf("%w: expected type %s but %s found instead", ErrInvalidTypes, t, inferredType)
×
3056
        }
1✔
3057
        return nil
3058
}
3059

504✔
3060
func (ce *CaseWhenExp) substitute(params map[string]interface{}) (ValueExp, error) {
504✔
3061
        var exp ValueExp
605✔
3062
        if ce.exp != nil {
101✔
3063
                e, err := ce.exp.substitute(params)
101✔
UNCOV
3064
                if err != nil {
×
3065
                        return nil, err
×
3066
                }
101✔
3067
                exp = e
3068
        }
3069

504✔
3070
        whenThen := make([]whenThenClause, len(ce.whenThen))
1,208✔
3071
        for i, wt := range ce.whenThen {
704✔
3072
                whenValue, err := wt.when.substitute(params)
704✔
UNCOV
3073
                if err != nil {
×
3074
                        return nil, err
×
3075
                }
704✔
3076
                whenThen[i].when = whenValue
704✔
3077

704✔
3078
                thenValue, err := wt.then.substitute(params)
704✔
UNCOV
3079
                if err != nil {
×
3080
                        return nil, err
×
3081
                }
704✔
3082
                whenThen[i].then = thenValue
3083
        }
3084

506✔
3085
        if ce.elseExp == nil {
2✔
3086
                return &CaseWhenExp{
2✔
3087
                        exp:      exp,
2✔
3088
                        whenThen: whenThen,
2✔
3089
                }, nil
2✔
3090
        }
3091

502✔
3092
        elseValue, err := ce.elseExp.substitute(params)
502✔
UNCOV
3093
        if err != nil {
×
3094
                return nil, err
×
3095
        }
502✔
3096
        return &CaseWhenExp{
502✔
3097
                exp:      exp,
502✔
3098
                whenThen: whenThen,
502✔
3099
                elseExp:  elseValue,
502✔
3100
        }, nil
3101
}
3102

7✔
3103
func (ce *CaseWhenExp) selectors() []Selector {
7✔
3104
        selectors := make([]Selector, 0)
8✔
3105
        if ce.exp != nil {
1✔
3106
                selectors = append(selectors, ce.exp.selectors()...)
1✔
3107
        }
3108

16✔
3109
        for _, wh := range ce.whenThen {
9✔
3110
                selectors = append(selectors, wh.when.selectors()...)
9✔
3111
                selectors = append(selectors, wh.then.selectors()...)
9✔
3112
        }
3113

9✔
3114
        if ce.elseExp == nil {
2✔
3115
                return selectors
2✔
3116
        }
5✔
3117
        return append(selectors, ce.elseExp.selectors()...)
3118
}
3119

304✔
3120
func (ce *CaseWhenExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
304✔
3121
        var searchValue TypedValue
405✔
3122
        if ce.exp != nil {
101✔
3123
                v, err := ce.exp.reduce(tx, row, implicitTable)
101✔
UNCOV
3124
                if err != nil {
×
3125
                        return nil, err
×
3126
                }
101✔
3127
                searchValue = v
203✔
3128
        } else {
203✔
3129
                searchValue = &Bool{val: true}
203✔
3130
        }
3131

734✔
3132
        for _, wt := range ce.whenThen {
430✔
3133
                v, err := wt.when.reduce(tx, row, implicitTable)
430✔
UNCOV
3134
                if err != nil {
×
3135
                        return nil, err
×
3136
                }
3137

431✔
3138
                if v.Type() != searchValue.Type() {
1✔
3139
                        return nil, fmt.Errorf("%w: argument of CASE/WHEN must be type %s, not type %s", ErrInvalidTypes, v.Type(), searchValue.Type())
1✔
3140
                }
3141

429✔
3142
                res, err := v.Compare(searchValue)
429✔
UNCOV
3143
                if err != nil {
×
3144
                        return nil, err
×
3145
                }
629✔
3146
                if res == 0 {
200✔
3147
                        return wt.then.reduce(tx, row, implicitTable)
200✔
3148
                }
3149
        }
3150

104✔
3151
        if ce.elseExp == nil {
1✔
3152
                return NewNull(AnyType), nil
1✔
3153
        }
102✔
3154
        return ce.elseExp.reduce(tx, row, implicitTable)
3155
}
3156

1✔
3157
func (ce *CaseWhenExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3158
        whenThen := make([]whenThenClause, len(ce.whenThen))
2✔
3159
        for i, wt := range ce.whenThen {
1✔
3160
                whenValue := wt.when.reduceSelectors(row, implicitTable)
1✔
3161
                whenThen[i].when = whenValue
1✔
3162

1✔
3163
                thenValue := wt.then.reduceSelectors(row, implicitTable)
1✔
3164
                whenThen[i].then = thenValue
1✔
3165
        }
3166

1✔
UNCOV
3167
        if ce.elseExp == nil {
×
3168
                return &CaseWhenExp{
×
3169
                        whenThen: whenThen,
×
3170
                }
×
3171
        }
3172

1✔
3173
        return &CaseWhenExp{
1✔
3174
                whenThen: whenThen,
1✔
3175
                elseExp:  ce.elseExp.reduceSelectors(row, implicitTable),
1✔
3176
        }
3177
}
3178

1✔
3179
func (ce *CaseWhenExp) isConstant() bool {
1✔
3180
        return false
1✔
3181
}
3182

1✔
3183
func (ce *CaseWhenExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
3184
        return nil
1✔
3185
}
3186

3✔
3187
func (ce *CaseWhenExp) String() string {
3✔
3188
        var sb strings.Builder
7✔
3189
        for _, wh := range ce.whenThen {
4✔
3190
                sb.WriteString(fmt.Sprintf("WHEN %s THEN %s ", wh.when.String(), wh.then.String()))
4✔
3191
        }
3192

5✔
3193
        if ce.elseExp != nil {
2✔
3194
                sb.WriteString("ELSE " + ce.elseExp.String() + " ")
2✔
3195
        }
3✔
3196
        return "CASE " + sb.String() + "END"
3197
}
3198

3199
type Comparison int
3200

3201
const (
3202
        EqualTo Comparison = iota
3203
        LowerThan
3204
        LowerOrEqualTo
3205
        GreaterThan
3206
        GreaterOrEqualTo
3207
)
3208

3209
type DataSource interface {
3210
        SQLStmt
3211
        Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error)
3212
        Alias() string
3213
}
3214

3215
type TargetEntry struct {
3216
        Exp ValueExp
3217
        As  string
3218
}
3219

3220
type SelectStmt struct {
3221
        distinct  bool
3222
        targets   []TargetEntry
3223
        selectors []Selector
3224
        ds        DataSource
3225
        indexOn   []string
3226
        joins     []*JoinSpec
3227
        where     ValueExp
3228
        groupBy   []*ColSelector
3229
        having    ValueExp
3230
        orderBy   []*OrdExp
3231
        limit     ValueExp
3232
        offset    ValueExp
3233
        as        string
3234
}
3235

3236
func NewSelectStmt(
3237
        targets []TargetEntry,
3238
        ds DataSource,
3239
        where ValueExp,
3240
        orderBy []*OrdExp,
3241
        limit ValueExp,
3242
        offset ValueExp,
71✔
3243
) *SelectStmt {
71✔
3244
        return &SelectStmt{
71✔
3245
                targets: targets,
71✔
3246
                ds:      ds,
71✔
3247
                where:   where,
71✔
3248
                orderBy: orderBy,
71✔
3249
                limit:   limit,
71✔
3250
                offset:  offset,
71✔
3251
        }
71✔
3252
}
3253

94✔
3254
func (stmt *SelectStmt) readOnly() bool {
94✔
3255
        return true
94✔
3256
}
3257

96✔
3258
func (stmt *SelectStmt) requiredPrivileges() []SQLPrivilege {
96✔
3259
        return []SQLPrivilege{SQLPrivilegeSelect}
96✔
3260
}
3261

53✔
3262
func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
53✔
3263
        _, err := stmt.execAt(ctx, tx, nil)
53✔
UNCOV
3264
        if err != nil {
×
3265
                return err
×
3266
        }
3267

3268
        // TODO: (jeroiraz) may be optimized so to resolve the query statement just once
53✔
3269
        rowReader, err := stmt.Resolve(ctx, tx, nil, nil)
54✔
3270
        if err != nil {
1✔
3271
                return err
1✔
3272
        }
52✔
3273
        defer rowReader.Close()
52✔
3274

52✔
3275
        return rowReader.InferParameters(ctx, params)
3276
}
3277

630✔
3278
func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
631✔
3279
        if stmt.groupBy == nil && stmt.having != nil {
1✔
3280
                return nil, ErrHavingClauseRequiresGroupClause
1✔
3281
        }
3282

714✔
3283
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
227✔
3284
                for _, sel := range stmt.targetSelectors() {
142✔
3285
                        _, isAgg := sel.(*AggColSelector)
144✔
3286
                        if !isAgg && !stmt.groupByContains(sel) {
2✔
3287
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3288
                        }
3289
                }
3290
        }
3291

784✔
3292
        if len(stmt.orderBy) > 0 {
352✔
3293
                for _, col := range stmt.orderBy {
380✔
3294
                        for _, sel := range col.exp.selectors() {
185✔
3295
                                _, isAgg := sel.(*AggColSelector)
187✔
3296
                                if (isAgg && !stmt.selectorAppearsInTargets(sel)) || (!isAgg && len(stmt.groupBy) > 0 && !stmt.groupByContains(sel)) {
2✔
3297
                                        return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3298
                                }
3299
                        }
3300
                }
3301
        }
625✔
3302
        return tx, nil
3303
}
3304

2,542✔
3305
func (stmt *SelectStmt) targetSelectors() []Selector {
3,503✔
3306
        if stmt.selectors == nil {
961✔
3307
                stmt.selectors = stmt.extractSelectors()
961✔
3308
        }
2,542✔
3309
        return stmt.selectors
3310
}
3311

4✔
3312
func (stmt *SelectStmt) selectorAppearsInTargets(s Selector) bool {
4✔
3313
        encSel := EncodeSelector(s.resolve(stmt.Alias()))
4✔
3314

12✔
3315
        for _, sel := range stmt.targetSelectors() {
11✔
3316
                if EncodeSelector(sel.resolve(stmt.Alias())) == encSel {
3✔
3317
                        return true
3✔
3318
                }
3319
        }
1✔
3320
        return false
3321
}
3322

57✔
3323
func (stmt *SelectStmt) groupByContains(sel Selector) bool {
57✔
3324
        encSel := EncodeSelector(sel.resolve(stmt.Alias()))
57✔
3325

137✔
3326
        for _, colSel := range stmt.groupBy {
134✔
3327
                if EncodeSelector(colSel.resolve(stmt.Alias())) == encSel {
54✔
3328
                        return true
54✔
3329
                }
3330
        }
3✔
3331
        return false
3332
}
3333

80✔
3334
func (stmt *SelectStmt) extractGroupByCols() []*AggColSelector {
80✔
3335
        cols := make([]*AggColSelector, 0, len(stmt.targets))
80✔
3336

214✔
3337
        for _, t := range stmt.targets {
134✔
3338
                selectors := t.Exp.selectors()
268✔
3339
                for _, sel := range selectors {
134✔
3340
                        aggSel, isAgg := sel.(*AggColSelector)
244✔
3341
                        if isAgg {
110✔
3342
                                cols = append(cols, aggSel)
110✔
3343
                        }
3344
                }
3345
        }
80✔
3346
        return cols
3347
}
3348

961✔
3349
func (stmt *SelectStmt) extractSelectors() []Selector {
961✔
3350
        selectors := make([]Selector, 0, len(stmt.targets))
1,849✔
3351
        for _, t := range stmt.targets {
888✔
3352
                selectors = append(selectors, t.Exp.selectors()...)
888✔
3353
        }
961✔
3354
        return selectors
3355
}
3356

968✔
3357
// Resolve builds the chain of row readers that evaluates the SELECT.
//
// The incoming scan specs are ignored; the statement derives its own through
// genScanSpecs. Starting from the data source, readers are stacked in this
// order:
//  1. joins;
//  2. the WHERE filter;
//  3. when the query aggregates or groups: an optional sort by
//     scanSpecs.groupBySortExps, then grouping, then the HAVING filter;
//  4. an ORDER BY sort, when scanSpecs.orderBySortExps is non-empty;
//  5. projection of the targets (under stmt.as);
//  6. DISTINCT;
//  7. OFFSET, then LIMIT (a LIMIT of 0 imposes no limit, a negative one is
//     ErrIllegalArguments).
//
// If any stage fails after the data source has been resolved, the current
// top of the reader chain is closed before the error is returned.
func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
	scanSpecs, err := stmt.genScanSpecs(tx, params)
	if err != nil {
		return nil, err
	}

	rowReader, err := stmt.ds.Resolve(ctx, tx, params, scanSpecs)
	if err != nil {
		return nil, err
	}
	// err is the named result and rowReader is captured by reference. Any
	// error return below therefore closes the outermost reader built so far
	// (presumably wrappers propagate Close to the readers they wrap).
	defer func() {
		if err != nil {
			rowReader.Close()
		}
	}()

	if stmt.joins != nil {
		var jointRowReader *jointRowReader
		jointRowReader, err = newJointRowReader(rowReader, stmt.joins)
		if err != nil {
			return nil, err
		}
		rowReader = jointRowReader
	}

	if stmt.where != nil {
		rowReader = newConditionalRowReader(rowReader, stmt.where)
	}

	if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
		// Grouping needs its input ordered by the grouping keys unless the
		// scan already provides that order.
		if len(scanSpecs.groupBySortExps) > 0 {
			var sortRowReader *sortRowReader
			sortRowReader, err = newSortRowReader(rowReader, scanSpecs.groupBySortExps)
			if err != nil {
				return nil, err
			}
			rowReader = sortRowReader
		}

		var groupedRowReader *groupedRowReader
		groupedRowReader, err = newGroupedRowReader(rowReader, allAggregations(stmt.targets), stmt.extractGroupByCols(), stmt.groupBy)
		if err != nil {
			return nil, err
		}
		rowReader = groupedRowReader

		if stmt.having != nil {
			rowReader = newConditionalRowReader(rowReader, stmt.having)
		}
	}

	// NOTE(review): the sort is only added when scanSpecs.orderBySortExps is
	// non-empty, but it sorts by stmt.orderBy rather than by
	// scanSpecs.orderBySortExps (unlike the GROUP BY branch above) — confirm
	// this is intended.
	if len(scanSpecs.orderBySortExps) > 0 {
		var sortRowReader *sortRowReader
		sortRowReader, err = newSortRowReader(rowReader, stmt.orderBy)
		if err != nil {
			return nil, err
		}
		rowReader = sortRowReader
	}

	projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.targets)
	if err != nil {
		return nil, err
	}
	rowReader = projectedRowReader

	if stmt.distinct {
		var distinctRowReader *distinctRowReader
		distinctRowReader, err = newDistinctRowReader(ctx, rowReader)
		if err != nil {
			return nil, err
		}
		rowReader = distinctRowReader
	}

	if stmt.offset != nil {
		var offset int
		offset, err = evalExpAsInt(tx, stmt.offset, params)
		if err != nil {
			return nil, fmt.Errorf("%w: invalid offset", err)
		}

		rowReader = newOffsetRowReader(rowReader, offset)
	}

	if stmt.limit != nil {
		var limit int
		limit, err = evalExpAsInt(tx, stmt.limit, params)
		if err != nil {
			return nil, fmt.Errorf("%w: invalid limit", err)
		}

		if limit < 0 {
			return nil, fmt.Errorf("%w: invalid limit", ErrIllegalArguments)
		}

		// A zero limit leaves the reader unbounded.
		if limit > 0 {
			rowReader = newLimitRowReader(rowReader, limit)
		}
	}
	return rowReader, nil
}
3459

951✔
3460
func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
957✔
3461
        if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
8✔
3462
                if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {
2✔
3463
                        return orderByExps, nil
2✔
3464
                }
3465

5✔
3466
                if ordExpsHasPrefix(groupByCols, orderByExps, stmt.Alias()) {
2✔
3467
                        for i := range orderByExps {
1✔
3468
                                groupByCols[i].descOrder = orderByExps[i].descOrder
1✔
3469
                        }
1✔
3470
                        return groupByCols, nil
3471
                }
3472
        }
948✔
3473
        return groupByCols, orderByExps
3474
}
3475

10✔
3476
func ordExpsHasPrefix(cols, prefix []*OrdExp, table string) bool {
12✔
3477
        if len(prefix) > len(cols) {
2✔
3478
                return false
2✔
3479
        }
3480

17✔
3481
        for i := range prefix {
9✔
3482
                ls := prefix[i].AsSelector()
9✔
3483
                rs := cols[i].AsSelector()
9✔
3484

9✔
UNCOV
3485
                if ls == nil || rs == nil {
×
3486
                        return false
×
3487
                }
3488

14✔
3489
                if EncodeSelector(ls.resolve(table)) != EncodeSelector(rs.resolve(table)) {
5✔
3490
                        return false
5✔
3491
                }
3492
        }
3✔
3493
        return true
3494
}
3495

968✔
3496
func (stmt *SelectStmt) groupByOrdExps() []*OrdExp {
968✔
3497
        groupByCols := stmt.groupBy
968✔
3498

968✔
3499
        ordExps := make([]*OrdExp, 0, len(groupByCols))
1,015✔
3500
        for _, col := range groupByCols {
47✔
3501
                ordExps = append(ordExps, &OrdExp{exp: col})
47✔
3502
        }
968✔
3503
        return ordExps
3504
}
3505

7✔
3506
func ordExpsHaveAggregations(exps []*OrdExp) bool {
17✔
3507
        for _, e := range exps {
11✔
3508
                if _, isAgg := e.exp.(*AggColSelector); isAgg {
1✔
3509
                        return true
1✔
3510
                }
3511
        }
6✔
3512
        return false
3513
}
3514

1,577✔
3515
func (stmt *SelectStmt) containsAggregations() bool {
3,173✔
3516
        for _, sel := range stmt.targetSelectors() {
1,596✔
3517
                _, isAgg := sel.(*AggColSelector)
1,759✔
3518
                if isAgg {
163✔
3519
                        return true
163✔
3520
                }
3521
        }
1,414✔
3522
        return false
3523
}
3524

147✔
3525
func evalExpAsInt(tx *SQLTx, exp ValueExp, params map[string]interface{}) (int, error) {
147✔
3526
        offset, err := exp.substitute(params)
147✔
UNCOV
3527
        if err != nil {
×
3528
                return 0, err
×
3529
        }
3530

147✔
3531
        texp, err := offset.reduce(tx, nil, "")
147✔
UNCOV
3532
        if err != nil {
×
3533
                return 0, err
×
3534
        }
3535

147✔
3536
        convVal, err := mayApplyImplicitConversion(texp.RawValue(), IntegerType)
147✔
UNCOV
3537
        if err != nil {
×
3538
                return 0, ErrInvalidValue
×
3539
        }
3540

147✔
3541
        num, ok := convVal.(int64)
147✔
UNCOV
3542
        if !ok {
×
3543
                return 0, ErrInvalidValue
×
3544
        }
3545

147✔
UNCOV
3546
        if num > math.MaxInt32 {
×
3547
                return 0, ErrInvalidValue
×
3548
        }
3549

147✔
3550
        return int(num), nil
3551
}
3552

167✔
3553
func (stmt *SelectStmt) Alias() string {
333✔
3554
        if stmt.as == "" {
166✔
3555
                return stmt.ds.Alias()
166✔
3556
        }
3557

1✔
3558
        return stmt.as
3559
}
3560

876✔
3561
func (stmt *SelectStmt) hasTxMetadata() bool {
1,692✔
3562
        for _, sel := range stmt.targetSelectors() {
816✔
3563
                switch s := sel.(type) {
702✔
3564
                case *ColSelector:
703✔
3565
                        if s.col == txMetadataCol {
1✔
3566
                                return true
1✔
3567
                        }
21✔
3568
                case *JSONSelector:
24✔
3569
                        if s.ColSelector.col == txMetadataCol {
3✔
3570
                                return true
3✔
3571
                        }
3572
                }
3573
        }
872✔
3574
        return false
3575
}
3576

968✔
3577
// genScanSpecs plans how the statement's data source is scanned. It computes:
//   - the index to scan;
//   - the per-column value ranges derived from WHERE (keyed by column ID);
//   - the scan direction;
//   - the GROUP BY / ORDER BY sorts still needed after the scan. Sorts that
//     the chosen index already satisfies are dropped.
//
// When the data source is not a plain table reference, no index is chosen and
// only the (possibly folded) sort expressions are returned. The same applies
// when the table is not found in the catalog but the engine has a resolver for
// its name; in that case the sort expressions are returned without folding.
func (stmt *SelectStmt) genScanSpecs(tx *SQLTx, params map[string]interface{}) (*ScanSpecs, error) {
	groupByCols, orderByCols := stmt.groupByOrdExps(), stmt.orderBy

	tableRef, isTableRef := stmt.ds.(*tableRef)
	if !isTableRef {
		// No catalog table, hence no index to exploit: just try to fold the
		// two sorts into one.
		groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)

		return &ScanSpecs{
			groupBySortExps: groupByCols,
			orderBySortExps: orderByCols,
		}, nil
	}

	table, err := tableRef.referencedTable(tx)
	if err != nil {
		// The catalog lookup failed; if a resolver is registered for this name,
		// the scan proceeds without index-based planning.
		if tx.engine.tableResolveFor(tableRef.table) != nil {
			return &ScanSpecs{
				groupBySortExps: groupByCols,
				orderBySortExps: orderByCols,
			}, nil
		}
		return nil, err
	}

	// Value ranges that the WHERE clause imposes on individual columns.
	rangesByColID := make(map[uint32]*typedValueRange)
	if stmt.where != nil {
		err = stmt.where.selectorRanges(table, tableRef.Alias(), params, rangesByColID)
		if err != nil {
			return nil, err
		}
	}

	// An index explicitly requested by the statement (stmt.indexOn) takes
	// precedence over one chosen from the sort requirements.
	preferredIndex, err := stmt.getPreferredIndex(table)
	if err != nil {
		return nil, err
	}

	var sortingIndex *Index
	if preferredIndex == nil {
		sortingIndex = stmt.selectSortingIndex(groupByCols, orderByCols, table, rangesByColID)
	} else {
		sortingIndex = preferredIndex
	}

	// Fall back to the primary index when no other index was selected.
	if sortingIndex == nil {
		sortingIndex = table.primaryIndex
	}

	if tableRef.history && !sortingIndex.IsPrimary() {
		return nil, fmt.Errorf("%w: historical queries are supported over primary index", ErrIllegalArguments)
	}

	// Drop sorts that the index scan already produces. GROUP BY is checked
	// first; ORDER BY is only delegated to the index when no GROUP BY sort
	// remains.
	var descOrder bool
	if len(groupByCols) > 0 && sortingIndex.coversOrdCols(groupByCols, rangesByColID) {
		groupByCols = nil
	}

	if len(groupByCols) == 0 && len(orderByCols) > 0 && sortingIndex.coversOrdCols(orderByCols, rangesByColID) {
		// The scan direction follows the first ORDER BY expression.
		descOrder = orderByCols[0].descOrder
		orderByCols = nil
	}

	// Whatever sorting is still needed after the scan: try to fold it into a
	// single pass.
	groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)

	return &ScanSpecs{
		Index:             sortingIndex,
		rangesByColID:     rangesByColID,
		IncludeHistory:    tableRef.history,
		IncludeTxMetadata: stmt.hasTxMetadata(),
		DescOrder:         descOrder,
		groupBySortExps:   groupByCols,
		orderBySortExps:   orderByCols,
	}, nil
}
3651

846✔
3652
func (stmt *SelectStmt) selectSortingIndex(groupByCols, orderByCols []*OrdExp, table *Table, rangesByColId map[uint32]*typedValueRange) *Index {
846✔
3653
        sortCols := groupByCols
1,666✔
3654
        if len(sortCols) == 0 {
820✔
3655
                sortCols = orderByCols
820✔
3656
        }
3657

1,553✔
3658
        if len(sortCols) == 0 {
707✔
3659
                return nil
707✔
3660
        }
3661

364✔
3662
        for _, idx := range table.indexes {
318✔
3663
                if idx.coversOrdCols(sortCols, rangesByColId) {
93✔
3664
                        return idx
93✔
3665
                }
3666
        }
46✔
3667
        return nil
3668
}
3669

876✔
3670
func (stmt *SelectStmt) getPreferredIndex(table *Table) (*Index, error) {
1,722✔
3671
        if len(stmt.indexOn) == 0 {
846✔
3672
                return nil, nil
846✔
3673
        }
3674

30✔
3675
        cols := make([]*Column, len(stmt.indexOn))
80✔
3676
        for i, colName := range stmt.indexOn {
50✔
3677
                col, err := table.GetColumnByName(colName)
50✔
UNCOV
3678
                if err != nil {
×
3679
                        return nil, err
×
3680
                }
3681

50✔
3682
                cols[i] = col
3683
        }
30✔
3684
        return table.GetIndexByName(indexName(table.name, cols))
3685
}
3686

3687
// UnionStmt combines the rows produced by the left and right data sources.
// When distinct is set, duplicate rows are removed from the combined result
// (UNION); otherwise every row is kept (UNION ALL).
type UnionStmt struct {
	distinct    bool
	left, right DataSource
}
3691

1✔
3692
func (stmt *UnionStmt) readOnly() bool {
1✔
3693
        return true
1✔
3694
}
3695

1✔
3696
func (stmt *UnionStmt) requiredPrivileges() []SQLPrivilege {
1✔
3697
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3698
}
3699

1✔
3700
func (stmt *UnionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3701
        err := stmt.left.inferParameters(ctx, tx, params)
1✔
UNCOV
3702
        if err != nil {
×
3703
                return err
×
3704
        }
1✔
3705
        return stmt.right.inferParameters(ctx, tx, params)
3706
}
3707

9✔
3708
func (stmt *UnionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
3709
        _, err := stmt.left.execAt(ctx, tx, params)
9✔
UNCOV
3710
        if err != nil {
×
3711
                return tx, err
×
3712
        }
3713

9✔
3714
        return stmt.right.execAt(ctx, tx, params)
3715
}
3716

11✔
3717
func (stmt *UnionStmt) resolveUnionAll(ctx context.Context, tx *SQLTx, params map[string]interface{}) (ret RowReader, err error) {
11✔
3718
        leftRowReader, err := stmt.left.Resolve(ctx, tx, params, nil)
12✔
3719
        if err != nil {
1✔
3720
                return nil, err
1✔
3721
        }
20✔
3722
        defer func() {
14✔
3723
                if err != nil {
4✔
3724
                        leftRowReader.Close()
4✔
3725
                }
3726
        }()
3727

10✔
3728
        rightRowReader, err := stmt.right.Resolve(ctx, tx, params, nil)
11✔
3729
        if err != nil {
1✔
3730
                return nil, err
1✔
3731
        }
18✔
3732
        defer func() {
12✔
3733
                if err != nil {
3✔
3734
                        rightRowReader.Close()
3✔
3735
                }
3736
        }()
3737

9✔
3738
        rowReader, err := newUnionRowReader(ctx, []RowReader{leftRowReader, rightRowReader})
12✔
3739
        if err != nil {
3✔
3740
                return nil, err
3✔
3741
        }
3742

6✔
3743
        return rowReader, nil
3744
}
3745

11✔
3746
func (stmt *UnionStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
11✔
3747
        rowReader, err := stmt.resolveUnionAll(ctx, tx, params)
16✔
3748
        if err != nil {
5✔
3749
                return nil, err
5✔
3750
        }
12✔
3751
        defer func() {
7✔
3752
                if err != nil {
1✔
3753
                        rowReader.Close()
1✔
3754
                }
3755
        }()
3756

11✔
3757
        if stmt.distinct {
5✔
3758
                distinctReader, err := newDistinctRowReader(ctx, rowReader)
6✔
3759
                if err != nil {
1✔
3760
                        return nil, err
1✔
3761
                }
4✔
3762
                rowReader = distinctReader
3763
        }
3764

5✔
3765
        return rowReader, nil
3766
}
UNCOV
3767

×
3768
func (stmt *UnionStmt) Alias() string {
×
3769
        return ""
×
3770
}
3771

179✔
3772
func NewTableRef(table string, as string) *tableRef {
179✔
3773
        return &tableRef{
179✔
3774
                table: table,
179✔
3775
                as:    as,
179✔
3776
        }
179✔
3777
}
3778

3779
// tableRef is a data source that refers to a table by name.
//
// Fields:
//   - table is the table name, looked up in the transaction's catalog or, failing
//     that, via an engine table resolver.
//   - history sets IncludeHistory in the generated scan specs, and is only
//     allowed over the primary index.
//   - period holds the bounds handed to the raw row reader.
//   - as is an optional alias; when empty, Alias returns the table name.
type tableRef struct {
	table   string
	history bool
	period  period
	as      string
}
3785

1✔
3786
func (ref *tableRef) readOnly() bool {
1✔
3787
        return true
1✔
3788
}
3789

1✔
3790
func (ref *tableRef) requiredPrivileges() []SQLPrivilege {
1✔
3791
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3792
}
3793

3794
// period holds the start and end bounds that restrict a table scan.
// NOTE(review): both bounds are pointers; presumably nil means "unbounded on
// that side" — confirm against the row reader.
type period struct {
	start *openPeriod
	end   *openPeriod
}

// openPeriod is one bound of a period. inclusive tells whether the instant
// itself belongs to the range.
type openPeriod struct {
	inclusive bool
	instant   periodInstant
}

// periodInstant is an expression identifying a point in history. Depending on
// instantType, the expression is evaluated as a transaction ID or as a
// timestamp (see periodInstant.resolve).
type periodInstant struct {
	exp         ValueExp
	instantType instantType
}

// instantType tells how a periodInstant's expression is interpreted.
type instantType = int

// Supported instant types: txInstant expressions evaluate to a transaction ID,
// timeInstant expressions to a timestamp.
const (
	txInstant instantType = iota
	timeInstant
)
3815

81✔
3816
func (i periodInstant) resolve(tx *SQLTx, params map[string]interface{}, asc, inclusive bool) (uint64, error) {
81✔
3817
        exp, err := i.exp.substitute(params)
81✔
UNCOV
3818
        if err != nil {
×
3819
                return 0, err
×
3820
        }
3821

81✔
3822
        instantVal, err := exp.reduce(tx, nil, "")
83✔
3823
        if err != nil {
2✔
3824
                return 0, err
2✔
3825
        }
3826

124✔
3827
        if i.instantType == txInstant {
45✔
3828
                txID, ok := instantVal.RawValue().(int64)
45✔
UNCOV
3829
                if !ok {
×
3830
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %s given", ErrIllegalArguments, instantVal.Type())
×
3831
                }
3832

52✔
3833
                if txID <= 0 {
7✔
3834
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %d given", ErrIllegalArguments, txID)
7✔
3835
                }
3836

61✔
3837
                if inclusive {
23✔
3838
                        return uint64(txID), nil
23✔
3839
                }
3840

26✔
3841
                if asc {
11✔
3842
                        return uint64(txID + 1), nil
11✔
3843
                }
3844

5✔
3845
                if txID <= 1 {
1✔
3846
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be greater than 1, %d given", ErrIllegalArguments, txID)
1✔
3847
                }
3848

3✔
3849
                return uint64(txID - 1), nil
34✔
3850
        } else {
34✔
3851

34✔
3852
                var ts time.Time
34✔
3853

67✔
3854
                if instantVal.Type() == TimestampType {
33✔
3855
                        ts = instantVal.RawValue().(time.Time)
34✔
3856
                } else {
1✔
3857
                        conv, err := getConverter(instantVal.Type(), TimestampType)
1✔
UNCOV
3858
                        if err != nil {
×
3859
                                return 0, err
×
3860
                        }
3861

1✔
3862
                        tval, err := conv(instantVal)
1✔
UNCOV
3863
                        if err != nil {
×
3864
                                return 0, err
×
3865
                        }
3866

1✔
3867
                        ts = tval.RawValue().(time.Time)
3868
                }
3869

34✔
3870
                sts := ts
34✔
3871

57✔
3872
                if asc {
34✔
3873
                        if !inclusive {
11✔
3874
                                sts = sts.Add(1 * time.Second)
11✔
3875
                        }
3876

23✔
3877
                        txHdr, err := tx.engine.store.FirstTxSince(sts)
34✔
3878
                        if err != nil {
11✔
3879
                                return 0, err
11✔
3880
                        }
3881

12✔
3882
                        return txHdr.ID, nil
3883
                }
3884

11✔
UNCOV
3885
                if !inclusive {
×
3886
                        sts = sts.Add(-1 * time.Second)
×
3887
                }
3888

11✔
3889
                txHdr, err := tx.engine.store.LastTxUntil(sts)
11✔
UNCOV
3890
                if err != nil {
×
3891
                        return 0, err
×
3892
                }
3893

11✔
3894
                return txHdr.ID, nil
3895
        }
3896
}
3897

3,943✔
3898
func (stmt *tableRef) referencedTable(tx *SQLTx) (*Table, error) {
3,943✔
3899
        table, err := tx.catalog.GetTableByName(stmt.table)
3,963✔
3900
        if err != nil {
20✔
3901
                return nil, err
20✔
3902
        }
3,923✔
3903
        return table, nil
3904
}
3905

1✔
3906
func (stmt *tableRef) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3907
        return nil
1✔
3908
}
UNCOV
3909

×
3910
func (stmt *tableRef) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3911
        return tx, nil
×
3912
}
3913

906✔
3914
func (stmt *tableRef) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
906✔
UNCOV
3915
        if tx == nil {
×
3916
                return nil, ErrIllegalArguments
×
3917
        }
3918

906✔
3919
        table, err := stmt.referencedTable(tx)
1,811✔
3920
        if err == nil {
905✔
3921
                return newRawRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
905✔
3922
        }
3923

2✔
3924
        if resolver := tx.engine.tableResolveFor(stmt.table); resolver != nil {
1✔
3925
                return resolver.Resolve(ctx, tx, stmt.Alias())
1✔
UNCOV
3926
        }
×
3927
        return nil, err
3928
}
3929

683✔
3930
func (stmt *tableRef) Alias() string {
1,206✔
3931
        if stmt.as == "" {
523✔
3932
                return stmt.table
523✔
3933
        }
160✔
3934
        return stmt.as
3935
}
3936

3937
// valuesDataSource is a data source made of literal rows, one RowSpec per row.
// When inferTypes is set, Resolve derives each column's type from the row
// values; otherwise every column is typed AnyType.
type valuesDataSource struct {
	inferTypes bool
	rows       []*RowSpec
}
3941

120✔
3942
func NewValuesDataSource(rows []*RowSpec) *valuesDataSource {
120✔
3943
        return &valuesDataSource{
120✔
3944
                rows: rows,
120✔
3945
        }
120✔
3946
}
UNCOV
3947

×
3948
func (ds *valuesDataSource) readOnly() bool {
×
3949
        return true
×
3950
}
3951

97✔
3952
func (ds *valuesDataSource) requiredPrivileges() []SQLPrivilege {
97✔
3953
        return nil
97✔
3954
}
UNCOV
3955

×
3956
func (ds *valuesDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3957
        return tx, nil
×
3958
}
UNCOV
3959

×
3960
func (ds *valuesDataSource) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
3961
        return nil
×
3962
}
UNCOV
3963

×
3964
func (ds *valuesDataSource) Alias() string {
×
3965
        return ""
×
3966
}
3967

2,132✔
3968
func (ds *valuesDataSource) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
2,132✔
UNCOV
3969
        if tx == nil {
×
3970
                return nil, ErrIllegalArguments
×
3971
        }
3972

2,132✔
3973
        cols := make([]ColDescriptor, len(ds.rows[0].Values))
9,788✔
3974
        for i := range cols {
7,656✔
3975
                cols[i] = ColDescriptor{
7,656✔
3976
                        Type:   AnyType,
7,656✔
3977
                        Column: fmt.Sprintf("col%d", i),
7,656✔
3978
                }
7,656✔
3979
        }
3980

2,132✔
3981
        emptyColsDesc, emptyParams := map[string]ColDescriptor{}, map[string]string{}
2,132✔
3982

2,145✔
3983
        if ds.inferTypes {
56✔
3984
                for i := 0; i < len(cols); i++ {
43✔
3985
                        t := AnyType
154✔
3986
                        for j := 0; j < len(ds.rows); j++ {
111✔
3987
                                e, err := ds.rows[j].Values[i].substitute(params)
111✔
UNCOV
3988
                                if err != nil {
×
3989
                                        return nil, err
×
3990
                                }
3991

111✔
3992
                                it, err := e.inferType(emptyColsDesc, emptyParams, "")
111✔
UNCOV
3993
                                if err != nil {
×
3994
                                        return nil, err
×
3995
                                }
3996

154✔
3997
                                if t == AnyType {
43✔
3998
                                        t = it
113✔
3999
                                } else if t != it && it != AnyType {
2✔
4000
                                        return nil, fmt.Errorf("cannot match types %s and %s", t, it)
2✔
4001
                                }
4002
                        }
41✔
4003
                        cols[i].Type = t
4004
                }
4005
        }
4006

2,130✔
4007
        values := make([][]ValueExp, len(ds.rows))
4,369✔
4008
        for i, rowSpec := range ds.rows {
2,239✔
4009
                values[i] = rowSpec.Values
2,239✔
4010
        }
2,130✔
4011
        return NewValuesRowReader(tx, params, cols, ds.inferTypes, "values", values)
4012
}
4013

4014
// JoinSpec describes one JOIN of a query: the join type, the data source being
// joined, and the join condition. indexOn lists column names.
// NOTE(review): indexOn presumably selects an explicit index for the joined
// source, like SelectStmt.indexOn does in getPreferredIndex — confirm.
type JoinSpec struct {
	joinType JoinType
	ds       DataSource
	cond     ValueExp
	indexOn  []string
}
4020

4021
// OrdExp is a sort expression: the expression to sort by, and descOrder set
// for descending order. It is used for ORDER BY and, through groupByOrdExps,
// for GROUP BY.
type OrdExp struct {
	exp       ValueExp
	descOrder bool
}
4025

714✔
4026
func (oc *OrdExp) AsSelector() Selector {
714✔
4027
        sel, ok := oc.exp.(Selector)
1,374✔
4028
        if ok {
660✔
4029
                return sel
660✔
4030
        }
54✔
4031
        return nil
4032
}
4033

1✔
4034
func NewOrdCol(table string, col string, descOrder bool) *OrdExp {
1✔
4035
        return &OrdExp{
1✔
4036
                exp:       NewColSelector(table, col),
1✔
4037
                descOrder: descOrder,
1✔
4038
        }
1✔
4039
}
4040

4041
// Selector is a value expression that names a column, possibly through an
// aggregate function.
//
// resolve returns three components: the aggregate function (empty for the
// plain column selectors in this file), the table, and the column name.
// implicitTable is used as the table when the selector is not table-qualified.
type Selector interface {
	ValueExp
	resolve(implicitTable string) (aggFn, table, col string)
}
4045

4046
// ColSelector references the column col, optionally qualified by table. An
// empty table means the implicit table passed in at evaluation time.
type ColSelector struct {
	table string
	col   string
}
4050

126✔
4051
func NewColSelector(table, col string) *ColSelector {
126✔
4052
        return &ColSelector{
126✔
4053
                table: table,
126✔
4054
                col:   col,
126✔
4055
        }
126✔
4056
}
4057

867,667✔
4058
func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) {
867,667✔
4059
        table = implicitTable
1,168,687✔
4060
        if sel.table != "" {
301,020✔
4061
                table = sel.table
301,020✔
4062
        }
867,667✔
4063
        return "", table, sel.col
4064
}
4065

685✔
4066
func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
685✔
4067
        _, table, col := sel.resolve(implicitTable)
685✔
4068
        encSel := EncodeSelector("", table, col)
685✔
4069

685✔
4070
        desc, ok := cols[encSel]
688✔
4071
        if !ok {
3✔
4072
                return AnyType, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
3✔
4073
        }
682✔
4074
        return desc.Type, nil
4075
}
4076

15✔
4077
func (sel *ColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
4078
        _, table, col := sel.resolve(implicitTable)
15✔
4079
        encSel := EncodeSelector("", table, col)
15✔
4080

15✔
4081
        desc, ok := cols[encSel]
17✔
4082
        if !ok {
2✔
4083
                return fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
2✔
4084
        }
4085

16✔
4086
        if desc.Type != t {
3✔
4087
                return fmt.Errorf("%w: %v(%s) can not be interpreted as type %v", ErrInvalidTypes, desc.Type, encSel, t)
3✔
4088
        }
4089

10✔
4090
        return nil
4091
}
4092

161,870✔
4093
func (sel *ColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
161,870✔
4094
        return sel, nil
161,870✔
4095
}
4096

715,151✔
4097
func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
715,152✔
4098
        if row == nil {
1✔
4099
                return nil, fmt.Errorf("%w: no row to evaluate in current context", ErrInvalidValue)
1✔
4100
        }
4101

715,150✔
4102
        aggFn, table, col := sel.resolve(implicitTable)
715,150✔
4103

715,150✔
4104
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
715,157✔
4105
        if !ok {
7✔
4106
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
7✔
4107
        }
715,143✔
4108
        return v, nil
4109
}
4110

926✔
4111
func (sel *ColSelector) selectors() []Selector {
926✔
4112
        return []Selector{sel}
926✔
4113
}
4114

568✔
4115
func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
568✔
4116
        aggFn, table, col := sel.resolve(implicitTable)
568✔
4117

568✔
4118
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
846✔
4119
        if !ok {
278✔
4120
                return sel
278✔
4121
        }
4122

290✔
4123
        return v
4124
}
4125

12✔
4126
func (sel *ColSelector) isConstant() bool {
12✔
4127
        return false
12✔
4128
}
4129

11✔
4130
func (sel *ColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
11✔
4131
        return nil
11✔
4132
}
4133

48✔
4134
func (sel *ColSelector) String() string {
48✔
4135
        return sel.col
48✔
4136
}
4137

4138
// AggColSelector references the result of the aggregate function aggFn
// applied to column col, optionally qualified by table. An empty table means
// the implicit table passed in at evaluation time.
type AggColSelector struct {
	aggFn AggregateFn
	table string
	col   string
}
4143

16✔
4144
func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector {
16✔
4145
        return &AggColSelector{
16✔
4146
                aggFn: aggFn,
16✔
4147
                table: table,
16✔
4148
                col:   col,
16✔
4149
        }
16✔
4150
}
4151

1,417,514✔
4152
// EncodeSelector builds the key under which a selector's value is stored in a
// row's ValuesBySelector map, in the form "aggFn(table.col)". Plain column
// selectors pass an empty aggFn, which yields "(table.col)".
func EncodeSelector(aggFn, table, col string) string {
	var sb strings.Builder
	sb.Grow(len(aggFn) + len(table) + len(col) + len("(.)"))
	sb.WriteString(aggFn)
	sb.WriteByte('(')
	sb.WriteString(table)
	sb.WriteByte('.')
	sb.WriteString(col)
	sb.WriteByte(')')
	return sb.String()
}
4155

1,586✔
4156
func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,586✔
4157
        table = implicitTable
1,717✔
4158
        if sel.table != "" {
131✔
4159
                table = sel.table
131✔
4160
        }
1,586✔
4161
        return sel.aggFn, table, sel.col
4162
}
4163

36✔
4164
func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
55✔
4165
        if sel.aggFn == COUNT {
19✔
4166
                return IntegerType, nil
19✔
4167
        }
4168

17✔
4169
        colSelector := &ColSelector{table: sel.table, col: sel.col}
17✔
4170

24✔
4171
        if sel.aggFn == SUM || sel.aggFn == AVG {
7✔
4172
                t, err := colSelector.inferType(cols, params, implicitTable)
7✔
UNCOV
4173
                if err != nil {
×
4174
                        return AnyType, err
×
4175
                }
4176

7✔
UNCOV
4177
                if t != IntegerType && t != Float64Type {
×
4178
                        return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
×
4179

×
4180
                }
4181

7✔
4182
                return t, nil
4183
        }
4184

10✔
4185
        return colSelector.inferType(cols, params, implicitTable)
4186
}
4187

8✔
4188
func (sel *AggColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
10✔
4189
        if sel.aggFn == COUNT {
3✔
4190
                if t != IntegerType {
1✔
4191
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4192
                }
1✔
4193
                return nil
4194
        }
4195

6✔
4196
        colSelector := &ColSelector{table: sel.table, col: sel.col}
6✔
4197

10✔
4198
        if sel.aggFn == SUM || sel.aggFn == AVG {
5✔
4199
                if t != IntegerType && t != Float64Type {
1✔
4200
                        return fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
1✔
4201
                }
4202
        }
4203

5✔
4204
        return colSelector.requiresType(t, cols, params, implicitTable)
4205
}
4206

412✔
4207
func (sel *AggColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
412✔
4208
        return sel, nil
412✔
4209
}
4210

459✔
4211
func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
460✔
4212
        if row == nil {
1✔
4213
                return nil, fmt.Errorf("%w: no row to evaluate aggregation (%s) in current context", ErrInvalidValue, sel.aggFn)
1✔
4214
        }
4215

458✔
4216
        v, ok := row.ValuesBySelector[EncodeSelector(sel.resolve(implicitTable))]
459✔
4217
        if !ok {
1✔
4218
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, sel.col)
1✔
4219
        }
457✔
4220
        return v, nil
4221
}
4222

232✔
4223
func (sel *AggColSelector) selectors() []Selector {
232✔
4224
        return []Selector{sel}
232✔
4225
}
UNCOV
4226

×
4227
func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4228
        return sel
×
4229
}
4230

1✔
4231
func (sel *AggColSelector) isConstant() bool {
1✔
4232
        return false
1✔
4233
}
UNCOV
4234

×
4235
func (sel *AggColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
4236
        return nil
×
4237
}
UNCOV
4238

×
4239
func (sel *AggColSelector) String() string {
×
4240
        return sel.aggFn + "(" + sel.col + ")"
×
4241
}
4242

4243
// NumExp is a binary arithmetic expression that applies op to the left and
// right operands.
type NumExp struct {
	op          NumOperator
	left, right ValueExp
}
4247

17✔
4248
func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
17✔
4249
        // First step - check if we can infer the type of sub-expressions
17✔
4250
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
17✔
UNCOV
4251
        if err != nil {
×
4252
                return AnyType, err
×
4253
        }
17✔
UNCOV
4254
        if tleft != AnyType && tleft != IntegerType && tleft != Float64Type && tleft != JSONType {
×
4255
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft)
×
4256
        }
4257

17✔
4258
        tright, err := bexp.right.inferType(cols, params, implicitTable)
17✔
UNCOV
4259
        if err != nil {
×
4260
                return AnyType, err
×
4261
        }
19✔
4262
        if tright != AnyType && tright != IntegerType && tright != Float64Type && tright != JSONType {
2✔
4263
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright)
2✔
4264
        }
4265

19✔
4266
        if tleft == IntegerType && tright == IntegerType {
4✔
4267
                // Both sides are integer types - the result is also integer
4✔
4268
                return IntegerType, nil
4✔
4269
        }
4270

20✔
4271
        if tleft != AnyType && tright != AnyType {
9✔
4272
                // Both sides have concrete types but at least one of them is float
9✔
4273
                return Float64Type, nil
9✔
4274
        }
4275

4276
        // Both sides are ambiguous
2✔
4277
        return AnyType, nil
4278
}
4279

11✔
4280
func copyParams(params map[string]SQLValueType) map[string]SQLValueType {
11✔
4281
        ret := make(map[string]SQLValueType, len(params))
15✔
4282
        for k, v := range params {
4✔
4283
                ret[k] = v
4✔
4284
        }
11✔
4285
        return ret
4286
}
4287

2✔
4288
func restoreParams(params, restore map[string]SQLValueType) {
2✔
UNCOV
4289
        for k := range params {
×
4290
                delete(params, k)
×
4291
        }
2✔
UNCOV
4292
        for k, v := range restore {
×
4293
                params[k] = v
×
4294
        }
4295
}
4296

7✔
4297
func (bexp *NumExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
8✔
4298
        if t != IntegerType && t != Float64Type {
1✔
4299
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4300
        }
4301

6✔
4302
        floatArgs := 2
6✔
4303
        paramsOrig := copyParams(params)
6✔
4304
        err := bexp.left.requiresType(t, cols, params, implicitTable)
7✔
4305
        if err != nil && t == Float64Type {
1✔
4306
                restoreParams(params, paramsOrig)
1✔
4307
                floatArgs--
1✔
4308
                err = bexp.left.requiresType(IntegerType, cols, params, implicitTable)
1✔
4309
        }
7✔
4310
        if err != nil {
1✔
4311
                return err
1✔
4312
        }
4313

5✔
4314
        paramsOrig = copyParams(params)
5✔
4315
        err = bexp.right.requiresType(t, cols, params, implicitTable)
6✔
4316
        if err != nil && t == Float64Type {
1✔
4317
                restoreParams(params, paramsOrig)
1✔
4318
                floatArgs--
1✔
4319
                err = bexp.right.requiresType(IntegerType, cols, params, implicitTable)
1✔
4320
        }
7✔
4321
        if err != nil {
2✔
4322
                return err
2✔
4323
        }
4324

3✔
UNCOV
4325
        if t == Float64Type && floatArgs == 0 {
×
4326
                // Currently this case requires explicit float cast
×
4327
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
×
4328
        }
4329

3✔
4330
        return nil
4331
}
4332

187✔
4333
func (bexp *NumExp) substitute(params map[string]interface{}) (ValueExp, error) {
187✔
4334
        rlexp, err := bexp.left.substitute(params)
187✔
UNCOV
4335
        if err != nil {
×
4336
                return nil, err
×
4337
        }
4338

187✔
4339
        rrexp, err := bexp.right.substitute(params)
187✔
UNCOV
4340
        if err != nil {
×
4341
                return nil, err
×
4342
        }
4343

187✔
4344
        bexp.left = rlexp
187✔
4345
        bexp.right = rrexp
187✔
4346

187✔
4347
        return bexp, nil
4348
}
4349

124,242✔
4350
func (bexp *NumExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
124,242✔
4351
        vl, err := bexp.left.reduce(tx, row, implicitTable)
124,242✔
UNCOV
4352
        if err != nil {
×
4353
                return nil, err
×
4354
        }
4355

124,242✔
4356
        vr, err := bexp.right.reduce(tx, row, implicitTable)
124,242✔
UNCOV
4357
        if err != nil {
×
4358
                return nil, err
×
4359
        }
4360

124,242✔
4361
        vl = unwrapJSON(vl)
124,242✔
4362
        vr = unwrapJSON(vr)
124,242✔
4363

124,242✔
4364
        return applyNumOperator(bexp.op, vl, vr)
4365
}
4366

248,484✔
4367
func unwrapJSON(v TypedValue) TypedValue {
248,584✔
4368
        if jsonVal, ok := v.(*JSON); ok {
200✔
4369
                if sv, isSimple := jsonVal.castToTypedValue(); isSimple {
100✔
4370
                        return sv
100✔
4371
                }
4372
        }
248,384✔
4373
        return v
4374
}
4375

13✔
4376
func (bexp *NumExp) selectors() []Selector {
13✔
4377
        return append(bexp.left.selectors(), bexp.right.selectors()...)
13✔
4378
}
4379

1✔
4380
func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4381
        return &NumExp{
1✔
4382
                op:    bexp.op,
1✔
4383
                left:  bexp.left.reduceSelectors(row, implicitTable),
1✔
4384
                right: bexp.right.reduceSelectors(row, implicitTable),
1✔
4385
        }
1✔
4386
}
4387

5✔
4388
func (bexp *NumExp) isConstant() bool {
5✔
4389
        return bexp.left.isConstant() && bexp.right.isConstant()
5✔
4390
}
4391

4✔
4392
func (bexp *NumExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
4✔
4393
        return nil
4✔
4394
}
4395

18✔
4396
func (bexp *NumExp) String() string {
18✔
4397
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), NumOperatorString(bexp.op), bexp.right.String())
18✔
4398
}
4399

4400
type NotBoolExp struct {
4401
        exp ValueExp
4402
}
4403

1✔
4404
func (bexp *NotBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4405
        err := bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
1✔
UNCOV
4406
        if err != nil {
×
4407
                return AnyType, err
×
4408
        }
4409

1✔
4410
        return BooleanType, nil
4411
}
4412

6✔
4413
func (bexp *NotBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4414
        if t != BooleanType {
1✔
4415
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4416
        }
4417

5✔
4418
        return bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
4419
}
4420

22✔
4421
func (bexp *NotBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
22✔
4422
        rexp, err := bexp.exp.substitute(params)
22✔
UNCOV
4423
        if err != nil {
×
4424
                return nil, err
×
4425
        }
4426

22✔
4427
        bexp.exp = rexp
22✔
4428

22✔
4429
        return bexp, nil
4430
}
4431

22✔
4432
func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
22✔
4433
        v, err := bexp.exp.reduce(tx, row, implicitTable)
22✔
UNCOV
4434
        if err != nil {
×
4435
                return nil, err
×
4436
        }
4437

22✔
4438
        r, isBool := v.RawValue().(bool)
22✔
UNCOV
4439
        if !isBool {
×
4440
                return nil, ErrInvalidCondition
×
4441
        }
4442

22✔
4443
        return &Bool{val: !r}, nil
4444
}
UNCOV
4445

×
4446
func (bexp *NotBoolExp) selectors() []Selector {
×
4447
        return bexp.exp.selectors()
×
4448
}
UNCOV
4449

×
4450
func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4451
        return &NotBoolExp{
×
4452
                exp: bexp.exp.reduceSelectors(row, implicitTable),
×
4453
        }
×
4454
}
4455

1✔
4456
func (bexp *NotBoolExp) isConstant() bool {
1✔
4457
        return bexp.exp.isConstant()
1✔
4458
}
4459

7✔
4460
func (bexp *NotBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
4461
        return nil
7✔
4462
}
4463

12✔
4464
func (bexp *NotBoolExp) String() string {
12✔
4465
        return fmt.Sprintf("(NOT %s)", bexp.exp.String())
12✔
4466
}
4467

4468
type LikeBoolExp struct {
4469
        val     ValueExp
4470
        notLike bool
4471
        pattern ValueExp
4472
}
4473

4✔
4474
func NewLikeBoolExp(val ValueExp, notLike bool, pattern ValueExp) *LikeBoolExp {
4✔
4475
        return &LikeBoolExp{
4✔
4476
                val:     val,
4✔
4477
                notLike: notLike,
4✔
4478
                pattern: pattern,
4✔
4479
        }
4✔
4480
}
4481

3✔
4482
func (bexp *LikeBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
4✔
4483
        if bexp.val == nil || bexp.pattern == nil {
1✔
4484
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4485
        }
4486

2✔
4487
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
3✔
4488
        if err != nil {
1✔
4489
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4490
        }
4491

1✔
4492
        return BooleanType, nil
4493
}
4494

7✔
4495
func (bexp *LikeBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
9✔
4496
        if bexp.val == nil || bexp.pattern == nil {
2✔
4497
                return fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
2✔
4498
        }
4499

7✔
4500
        if t != BooleanType {
2✔
4501
                return fmt.Errorf("error using the value of the LIKE operator as %s: %w", t, ErrInvalidTypes)
2✔
4502
        }
4503

3✔
4504
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
4✔
4505
        if err != nil {
1✔
4506
                return fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4507
        }
4508

2✔
4509
        return nil
4510
}
4511

135✔
4512
func (bexp *LikeBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
136✔
4513
        if bexp.val == nil || bexp.pattern == nil {
1✔
4514
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4515
        }
4516

134✔
4517
        val, err := bexp.val.substitute(params)
134✔
UNCOV
4518
        if err != nil {
×
4519
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4520
        }
4521

134✔
4522
        pattern, err := bexp.pattern.substitute(params)
134✔
UNCOV
4523
        if err != nil {
×
4524
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4525
        }
4526

134✔
4527
        return &LikeBoolExp{
134✔
4528
                val:     val,
134✔
4529
                notLike: bexp.notLike,
134✔
4530
                pattern: pattern,
134✔
4531
        }, nil
4532
}
4533

141✔
4534
func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
142✔
4535
        if bexp.val == nil || bexp.pattern == nil {
1✔
4536
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4537
        }
4538

140✔
4539
        rval, err := bexp.val.reduce(tx, row, implicitTable)
140✔
UNCOV
4540
        if err != nil {
×
4541
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4542
        }
4543

141✔
4544
        if rval.IsNull() {
1✔
4545
                return &Bool{val: bexp.notLike}, nil
1✔
4546
        }
4547

139✔
4548
        rvalStr, ok := rval.RawValue().(string)
140✔
4549
        if !ok {
1✔
4550
                return nil, fmt.Errorf("error in 'LIKE' clause: %w (expecting %s)", ErrInvalidTypes, VarcharType)
1✔
4551
        }
4552

138✔
4553
        rpattern, err := bexp.pattern.reduce(tx, row, implicitTable)
138✔
UNCOV
4554
        if err != nil {
×
4555
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4556
        }
4557

138✔
UNCOV
4558
        if rpattern.Type() != VarcharType {
×
4559
                return nil, fmt.Errorf("error evaluating 'LIKE' clause: %w", ErrInvalidTypes)
×
4560
        }
4561

138✔
4562
        matched, err := regexp.MatchString(rpattern.RawValue().(string), rvalStr)
138✔
UNCOV
4563
        if err != nil {
×
4564
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4565
        }
4566

138✔
4567
        return &Bool{val: matched != bexp.notLike}, nil
4568
}
4569

1✔
4570
func (bexp *LikeBoolExp) selectors() []Selector {
1✔
4571
        return bexp.val.selectors()
1✔
4572
}
4573

1✔
4574
func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4575
        return bexp
1✔
4576
}
4577

2✔
4578
func (bexp *LikeBoolExp) isConstant() bool {
2✔
4579
        return false
2✔
4580
}
4581

8✔
4582
func (bexp *LikeBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
4583
        return nil
8✔
4584
}
4585

5✔
4586
func (bexp *LikeBoolExp) String() string {
5✔
4587
        fmtStr := "(%s LIKE %s)"
6✔
4588
        if bexp.notLike {
1✔
4589
                fmtStr = "(%s NOT LIKE %s)"
1✔
4590
        }
5✔
4591
        return fmt.Sprintf(fmtStr, bexp.val.String(), bexp.pattern.String())
4592
}
4593

4594
type CmpBoolExp struct {
4595
        op          CmpOperator
4596
        left, right ValueExp
4597
}
4598

67✔
4599
func NewCmpBoolExp(op CmpOperator, left, right ValueExp) *CmpBoolExp {
67✔
4600
        return &CmpBoolExp{
67✔
4601
                op:    op,
67✔
4602
                left:  left,
67✔
4603
                right: right,
67✔
4604
        }
67✔
4605
}
UNCOV
4606

×
4607
func (bexp *CmpBoolExp) Left() ValueExp {
×
4608
        return bexp.left
×
4609
}
UNCOV
4610

×
4611
func (bexp *CmpBoolExp) Right() ValueExp {
×
4612
        return bexp.right
×
4613
}
UNCOV
4614

×
4615
func (bexp *CmpBoolExp) OP() CmpOperator {
×
4616
        return bexp.op
×
4617
}
4618

63✔
4619
func (bexp *CmpBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
63✔
4620
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
63✔
UNCOV
4621
        if err != nil {
×
4622
                return AnyType, err
×
4623
        }
4624

63✔
4625
        tright, err := bexp.right.inferType(cols, params, implicitTable)
65✔
4626
        if err != nil {
2✔
4627
                return AnyType, err
2✔
4628
        }
4629

4630
        // unification step
4631

74✔
4632
        if tleft == tright {
13✔
4633
                return BooleanType, nil
13✔
4634
        }
4635

48✔
4636
        _, ok := coerceTypes(tleft, tright)
52✔
4637
        if !ok {
4✔
4638
                return AnyType, fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, tleft, tright)
4✔
4639
        }
4640

47✔
4641
        if tleft == AnyType {
3✔
4642
                err = bexp.left.requiresType(tright, cols, params, implicitTable)
3✔
UNCOV
4643
                if err != nil {
×
4644
                        return AnyType, err
×
4645
                }
4646
        }
4647

84✔
4648
        if tright == AnyType {
40✔
4649
                err = bexp.right.requiresType(tleft, cols, params, implicitTable)
41✔
4650
                if err != nil {
1✔
4651
                        return AnyType, err
1✔
4652
                }
4653
        }
43✔
4654
        return BooleanType, nil
4655
}
4656

48✔
4657
func coerceTypes(t1, t2 SQLValueType) (SQLValueType, bool) {
48✔
UNCOV
4658
        switch {
×
4659
        case t1 == t2:
×
4660
                return t1, true
3✔
4661
        case t1 == AnyType:
3✔
4662
                return t2, true
40✔
4663
        case t2 == AnyType:
40✔
4664
                return t1, true
4665
        case (t1 == IntegerType && t2 == Float64Type) ||
1✔
4666
                (t1 == Float64Type && t2 == IntegerType):
1✔
4667
                return Float64Type, true
4668
        }
4✔
4669
        return "", false
4670
}
4671

41✔
4672
func (bexp *CmpBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
42✔
4673
        if t != BooleanType {
1✔
4674
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4675
        }
4676

40✔
4677
        _, err := bexp.inferType(cols, params, implicitTable)
40✔
4678
        return err
4679
}
4680

14,324✔
4681
func (bexp *CmpBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
14,324✔
4682
        rlexp, err := bexp.left.substitute(params)
14,324✔
UNCOV
4683
        if err != nil {
×
4684
                return nil, err
×
4685
        }
4686

14,324✔
4687
        rrexp, err := bexp.right.substitute(params)
14,325✔
4688
        if err != nil {
1✔
4689
                return nil, err
1✔
4690
        }
4691

14,323✔
4692
        bexp.left = rlexp
14,323✔
4693
        bexp.right = rrexp
14,323✔
4694

14,323✔
4695
        return bexp, nil
4696
}
4697

13,923✔
4698
func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
13,923✔
4699
        vl, err := bexp.left.reduce(tx, row, implicitTable)
13,925✔
4700
        if err != nil {
2✔
4701
                return nil, err
2✔
4702
        }
4703

13,921✔
4704
        vr, err := bexp.right.reduce(tx, row, implicitTable)
13,923✔
4705
        if err != nil {
2✔
4706
                return nil, err
2✔
4707
        }
4708

13,919✔
4709
        r, err := vl.Compare(vr)
13,924✔
4710
        if err != nil {
5✔
4711
                return nil, err
5✔
4712
        }
4713

13,914✔
4714
        return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil
4715
}
4716

12✔
4717
func (bexp *CmpBoolExp) selectors() []Selector {
12✔
4718
        return append(bexp.left.selectors(), bexp.right.selectors()...)
12✔
4719
}
4720

282✔
4721
func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
282✔
4722
        return &CmpBoolExp{
282✔
4723
                op:    bexp.op,
282✔
4724
                left:  bexp.left.reduceSelectors(row, implicitTable),
282✔
4725
                right: bexp.right.reduceSelectors(row, implicitTable),
282✔
4726
        }
282✔
4727
}
4728

2✔
4729
func (bexp *CmpBoolExp) isConstant() bool {
2✔
4730
        return bexp.left.isConstant() && bexp.right.isConstant()
2✔
4731
}
4732

607✔
4733
func (bexp *CmpBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1,423✔
4734
        matchingFunc := func(_, right ValueExp) (*ColSelector, ValueExp, bool) {
816✔
4735
                s, isSel := bexp.left.(*ColSelector)
1,214✔
4736
                if isSel && s.col != revCol && bexp.right.isConstant() {
398✔
4737
                        return s, right, true
398✔
4738
                }
418✔
4739
                return nil, nil, false
4740
        }
4741

607✔
4742
        sel, c, ok := matchingFunc(bexp.left, bexp.right)
816✔
4743
        if !ok {
209✔
4744
                sel, c, ok = matchingFunc(bexp.right, bexp.left)
209✔
4745
        }
4746

816✔
4747
        if !ok {
209✔
4748
                return nil
209✔
4749
        }
4750

398✔
4751
        aggFn, t, col := sel.resolve(table.name)
412✔
4752
        if aggFn != "" || t != asTable {
14✔
4753
                return nil
14✔
4754
        }
4755

384✔
4756
        column, err := table.GetColumnByName(col)
385✔
4757
        if err != nil {
1✔
4758
                return err
1✔
4759
        }
4760

383✔
4761
        val, err := c.substitute(params)
442✔
4762
        if errors.Is(err, ErrMissingParameter) {
59✔
4763
                // TODO: not supported when parameters are not provided during query resolution
59✔
4764
                return nil
59✔
4765
        }
324✔
UNCOV
4766
        if err != nil {
×
4767
                return err
×
4768
        }
4769

324✔
4770
        rval, err := val.reduce(nil, nil, table.name)
325✔
4771
        if err != nil {
1✔
4772
                return err
1✔
4773
        }
4774

323✔
4775
        return updateRangeFor(column.id, rval, bexp.op, rangesByColID)
4776
}
4777

20✔
4778
func (bexp *CmpBoolExp) String() string {
20✔
4779
        opStr := CmpOperatorToString(bexp.op)
20✔
4780
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), opStr, bexp.right.String())
20✔
4781
}
4782

4783
func updateRangeFor(colID uint32, val TypedValue, cmp CmpOperator, rangesByColID map[uint32]*typedValueRange) error {
4784
        currRange, ranged := rangesByColID[colID]
4785
        var newRange *typedValueRange
4786

4787
        switch cmp {
4788
        case EQ:
4789
                {
4790
                        newRange = &typedValueRange{
4791
                                lRange: &typedValueSemiRange{
4792
                                        val:       val,
4793
                                        inclusive: true,
4794
                                },
4795
                                hRange: &typedValueSemiRange{
4796
                                        val:       val,
4797
                                        inclusive: true,
4798
                                },
3✔
4799
                        }
3✔
4800
                }
3✔
UNCOV
4801
        case LT:
×
UNCOV
4802
                {
×
4803
                        newRange = &typedValueRange{
4804
                                hRange: &typedValueSemiRange{
3✔
4805
                                        val: val,
3✔
4806
                                },
3✔
UNCOV
4807
                        }
×
UNCOV
4808
                }
×
4809
        case LE:
3✔
4810
                {
4811
                        newRange = &typedValueRange{
4812
                                hRange: &typedValueSemiRange{
4✔
4813
                                        val:       val,
4✔
UNCOV
4814
                                        inclusive: true,
×
UNCOV
4815
                                },
×
4816
                        }
4✔
4817
                }
4818
        case GT:
4819
                {
18✔
4820
                        newRange = &typedValueRange{
18✔
4821
                                lRange: &typedValueSemiRange{
18✔
UNCOV
4822
                                        val: val,
×
UNCOV
4823
                                },
×
4824
                        }
18✔
4825
                }
18✔
4826
        case GE:
18✔
4827
                {
18✔
4828
                        newRange = &typedValueRange{
4829
                                lRange: &typedValueSemiRange{
4830
                                        val:       val,
12✔
4831
                                        inclusive: true,
12✔
4832
                                },
12✔
4833
                        }
4834
                }
18✔
4835
        case NE:
18✔
4836
                {
18✔
UNCOV
4837
                        return nil
×
UNCOV
4838
                }
×
4839
        }
4840

18✔
UNCOV
4841
        if !ranged {
×
UNCOV
4842
                rangesByColID[colID] = newRange
×
4843
                return nil
4844
        }
18✔
UNCOV
4845

×
UNCOV
4846
        return currRange.refineWith(newRange)
×
4847
}
4848

22✔
4849
func cmpSatisfiesOp(cmp int, op CmpOperator) bool {
4✔
4850
        switch {
4✔
UNCOV
4851
        case cmp == 0:
×
UNCOV
4852
                {
×
4853
                        return op == EQ || op == LE || op == GE
4✔
4854
                }
4✔
UNCOV
4855
        case cmp < 0:
×
UNCOV
4856
                {
×
4857
                        return op == NE || op == LT || op == LE
4✔
4858
                }
4859
        case cmp > 0:
4860
                {
18✔
4861
                        return op == NE || op == GT || op == GE
18✔
4862
                }
18✔
4863
        }
18✔
4864
        return false
18✔
4865
}
3✔
4866

3✔
4867
type BinBoolExp struct {
3✔
4868
        op          LogicOperator
3✔
4869
        left, right ValueExp
3✔
4870
}
3✔
4871

3✔
4872
func NewBinBoolExp(op LogicOperator, lrexp, rrexp ValueExp) *BinBoolExp {
3✔
4873
        bexp := &BinBoolExp{
3✔
4874
                op: op,
3✔
4875
        }
3✔
4876

3✔
4877
        bexp.left = lrexp
UNCOV
4878
        bexp.right = rrexp
×
4879

4880
        return bexp
UNCOV
4881
}
×
UNCOV
4882

×
UNCOV
4883
func (bexp *BinBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
×
UNCOV
4884
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
×
UNCOV
4885
        if err != nil {
×
4886
                return AnyType, err
×
4887
        }
4888

1✔
4889
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
1✔
4890
        if err != nil {
1✔
4891
                return AnyType, err
UNCOV
4892
        }
×
UNCOV
4893

×
UNCOV
4894
        return BooleanType, nil
×
4895
}
4896

6✔
4897
func (bexp *BinBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
4898
        if t != BooleanType {
6✔
4899
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
4900
        }
323✔
4901

323✔
4902
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
323✔
4903
        if err != nil {
323✔
4904
                return err
323✔
4905
        }
250✔
4906

500✔
4907
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
250✔
4908
        if err != nil {
250✔
4909
                return err
250✔
4910
        }
250✔
4911

250✔
4912
        return nil
250✔
4913
}
250✔
4914

250✔
4915
func (bexp *BinBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
250✔
4916
        rlexp, err := bexp.left.substitute(params)
250✔
4917
        if err != nil {
250✔
4918
                return nil, err
13✔
4919
        }
26✔
4920

13✔
4921
        rrexp, err := bexp.right.substitute(params)
13✔
4922
        if err != nil {
13✔
4923
                return nil, err
13✔
4924
        }
13✔
4925

13✔
4926
        bexp.left = rlexp
12✔
4927
        bexp.right = rrexp
24✔
4928

12✔
4929
        return bexp, nil
12✔
4930
}
12✔
4931

12✔
4932
func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
12✔
4933
        vl, err := bexp.left.reduce(tx, row, implicitTable)
12✔
4934
        if err != nil {
12✔
4935
                return nil, err
18✔
4936
        }
36✔
4937

18✔
4938
        bl, isBool := vl.(*Bool)
18✔
4939
        if !isBool {
18✔
4940
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
18✔
4941
        }
18✔
4942

18✔
4943
        // short-circuit evaluation
18✔
4944
        if (bl.val && bexp.op == Or) || (!bl.val && bexp.op == And) {
36✔
4945
                return &Bool{val: bl.val}, nil
18✔
4946
        }
18✔
4947

18✔
4948
        vr, err := bexp.right.reduce(tx, row, implicitTable)
18✔
4949
        if err != nil {
18✔
4950
                return nil, err
18✔
4951
        }
18✔
4952

12✔
4953
        br, isBool := vr.(*Bool)
24✔
4954
        if !isBool {
12✔
4955
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
12✔
4956
        }
4957

4958
        switch bexp.op {
617✔
4959
        case And:
306✔
4960
                {
306✔
4961
                        return &Bool{val: bl.val && br.val}, nil
306✔
4962
                }
4963
        case Or:
5✔
4964
                {
4965
                        return &Bool{val: bl.val || br.val}, nil
4966
                }
13,914✔
4967
        }
13,914✔
4968

1,169✔
4969
        return nil, ErrUnexpected
2,338✔
4970
}
1,169✔
4971

1,169✔
4972
func (bexp *BinBoolExp) selectors() []Selector {
6,489✔
4973
        return append(bexp.left.selectors(), bexp.right.selectors()...)
12,978✔
4974
}
6,489✔
4975

6,489✔
4976
func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
6,256✔
4977
        return &BinBoolExp{
12,512✔
4978
                op:    bexp.op,
6,256✔
4979
                left:  bexp.left.reduceSelectors(row, implicitTable),
6,256✔
4980
                right: bexp.right.reduceSelectors(row, implicitTable),
UNCOV
4981
        }
×
4982
}
4983

4984
func (bexp *BinBoolExp) isConstant() bool {
4985
        return bexp.left.isConstant() && bexp.right.isConstant()
4986
}
4987

4988
func (bexp *BinBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
4989
        if bexp.op == And {
18✔
4990
                err := bexp.left.selectorRanges(table, asTable, params, rangesByColID)
18✔
4991
                if err != nil {
18✔
4992
                        return err
18✔
4993
                }
18✔
4994

18✔
4995
                return bexp.right.selectorRanges(table, asTable, params, rangesByColID)
18✔
4996
        }
18✔
4997

18✔
4998
        lRanges := make(map[uint32]*typedValueRange)
18✔
4999
        rRanges := make(map[uint32]*typedValueRange)
5000

20✔
5001
        err := bexp.left.selectorRanges(table, asTable, params, lRanges)
20✔
5002
        if err != nil {
20✔
5003
                return err
×
5004
        }
×
5005

5006
        err = bexp.right.selectorRanges(table, asTable, params, rRanges)
20✔
5007
        if err != nil {
22✔
5008
                return err
2✔
5009
        }
2✔
5010

5011
        for colID, lr := range lRanges {
18✔
5012
                rr, ok := rRanges[colID]
5013
                if !ok {
5014
                        continue
22✔
5015
                }
25✔
5016

3✔
5017
                err = lr.extendWith(rr)
3✔
5018
                if err != nil {
5019
                        return err
19✔
5020
                }
20✔
5021

1✔
5022
                rangesByColID[colID] = lr
1✔
5023
        }
5024

18✔
5025
        return nil
18✔
UNCOV
5026
}
×
UNCOV
5027

×
5028
func (bexp *BinBoolExp) String() string {
5029
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), LogicOperatorToString(bexp.op), bexp.right.String())
18✔
5030
}
5031

5032
type ExistsBoolExp struct {
576✔
5033
        q DataSource
576✔
5034
}
576✔
UNCOV
5035

×
UNCOV
5036
func (bexp *ExistsBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
×
5037
        return AnyType, fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
5038
}
576✔
5039

576✔
UNCOV
5040
func (bexp *ExistsBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
UNCOV
5041
        return fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
×
5042
}
5043

576✔
5044
func (bexp *ExistsBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
576✔
5045
        return bexp, nil
576✔
5046
}
576✔
5047

5048
func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5049
        return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported)
539✔
5050
}
539✔
5051

540✔
5052
func (bexp *ExistsBoolExp) selectors() []Selector {
1✔
5053
        return nil
1✔
5054
}
5055

538✔
5056
func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
538✔
UNCOV
5057
        return bexp
×
UNCOV
5058
}
×
5059

5060
func (bexp *ExistsBoolExp) isConstant() bool {
5061
        return false
714✔
5062
}
176✔
5063

176✔
5064
func (bexp *ExistsBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
5065
        return nil
362✔
5066
}
363✔
5067

1✔
5068
func (bexp *ExistsBoolExp) String() string {
1✔
5069
        return ""
5070
}
361✔
5071

361✔
UNCOV
5072
type InSubQueryExp struct {
×
UNCOV
5073
        val   ValueExp
×
5074
        notIn bool
5075
        q     *SelectStmt
361✔
5076
}
340✔
5077

680✔
5078
func (bexp *InSubQueryExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
340✔
5079
        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
340✔
5080
}
21✔
5081

42✔
5082
func (bexp *InSubQueryExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
21✔
5083
        return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
21✔
5084
}
5085

UNCOV
5086
func (bexp *InSubQueryExp) substitute(params map[string]interface{}) (ValueExp, error) {
×
5087
        return bexp, nil
5088
}
5089

2✔
5090
func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5091
        return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
2✔
5092
}
5093

15✔
5094
func (bexp *InSubQueryExp) selectors() []Selector {
15✔
5095
        return bexp.val.selectors()
15✔
5096
}
15✔
5097

15✔
5098
func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
15✔
5099
        return bexp
15✔
5100
}
5101

1✔
5102
func (bexp *InSubQueryExp) isConstant() bool {
1✔
5103
        return false
1✔
5104
}
5105

153✔
5106
func (bexp *InSubQueryExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
292✔
5107
        return nil
139✔
5108
}
139✔
UNCOV
5109

×
5110
func (bexp *InSubQueryExp) String() string {
×
5111
        return ""
5112
}
139✔
5113

5114
// TODO: once InSubQueryExp is supported, this struct may become obsolete by creating a ListDataSource struct
5115
type InListExp struct {
14✔
5116
        val    ValueExp
14✔
5117
        notIn  bool
14✔
5118
        values []ValueExp
14✔
5119
}
14✔
UNCOV
5120

×
UNCOV
5121
func (bexp *InListExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
×
5122
        t, err := bexp.val.inferType(cols, params, implicitTable)
5123
        if err != nil {
14✔
5124
                return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
14✔
UNCOV
5125
        }
×
UNCOV
5126

×
5127
        for _, v := range bexp.values {
5128
                err = v.requiresType(t, cols, params, implicitTable)
21✔
5129
                if err != nil {
7✔
5130
                        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
9✔
5131
                }
2✔
5132
        }
5133

5134
        return BooleanType, nil
5✔
5135
}
5✔
UNCOV
5136

×
UNCOV
5137
func (bexp *InListExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
5138
        _, err := bexp.inferType(cols, params, implicitTable)
5139
        if err != nil {
5✔
5140
                return err
5141
        }
5142

14✔
5143
        if t != BooleanType {
5144
                return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrInvalidTypes)
5145
        }
31✔
5146

31✔
5147
        return nil
31✔
5148
}
5149

5150
func (bexp *InListExp) substitute(params map[string]interface{}) (ValueExp, error) {
5151
        val, err := bexp.val.substitute(params)
5152
        if err != nil {
5153
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5154
        }
1✔
5155

1✔
5156
        values := make([]ValueExp, len(bexp.values))
5157

1✔
5158
        for i, val := range bexp.values {
1✔
5159
                values[i], err = val.substitute(params)
1✔
5160
                if err != nil {
5161
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5162
                }
1✔
5163
        }
1✔
5164

5165
        return &InListExp{
2✔
5166
                val:    val,
2✔
5167
                notIn:  bexp.notIn,
2✔
5168
                values: values,
5169
        }, nil
1✔
5170
}
1✔
5171

1✔
5172
func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5173
        rval, err := bexp.val.reduce(tx, row, implicitTable)
1✔
5174
        if err != nil {
1✔
5175
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5176
        }
5177

2✔
5178
        var found bool
2✔
5179

2✔
5180
        for _, v := range bexp.values {
5181
                rv, err := v.reduce(tx, row, implicitTable)
1✔
5182
                if err != nil {
1✔
5183
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5184
                }
UNCOV
5185

×
UNCOV
5186
                r, err := rval.Compare(rv)
×
UNCOV
5187
                if err != nil {
×
5188
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
5189
                }
5190

5191
                if r == 0 {
5192
                        // TODO: short-circuit evaluation may be preferred when upfront static type inference is in place
5193
                        found = found || true
5194
                }
5195
        }
1✔
5196

1✔
5197
        return &Bool{val: found != bexp.notIn}, nil
1✔
5198
}
5199

1✔
5200
func (bexp *InListExp) selectors() []Selector {
1✔
5201
        selectors := make([]Selector, 0, len(bexp.values))
1✔
5202
        for _, v := range bexp.values {
5203
                selectors = append(selectors, v.selectors()...)
1✔
5204
        }
1✔
5205
        return append(bexp.val.selectors(), selectors...)
1✔
5206
}
5207

2✔
5208
func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
2✔
5209
        values := make([]ValueExp, len(bexp.values))
2✔
5210

5211
        for i, val := range bexp.values {
1✔
5212
                values[i] = val.reduceSelectors(row, implicitTable)
1✔
5213
        }
1✔
5214

5215
        return &InListExp{
1✔
5216
                val:    bexp.val.reduceSelectors(row, implicitTable),
1✔
5217
                values: values,
1✔
5218
        }
5219
}
1✔
5220

1✔
5221
func (bexp *InListExp) isConstant() bool {
1✔
5222
        return false
5223
}
1✔
5224

1✔
5225
func (bexp *InListExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5226
        // TODO: may be determiined by smallest and bigggest value in the list
UNCOV
5227
        return nil
×
UNCOV
5228
}
×
UNCOV
5229

×
5230
func (bexp *InListExp) String() string {
5231
        values := make([]string, len(bexp.values))
5232
        for i, exp := range bexp.values {
5233
                values[i] = exp.String()
5234
        }
5235
        return fmt.Sprintf("%s IN (%s)", bexp.val.String(), strings.Join(values, ","))
5236
}
5237

5238
type FnDataSourceStmt struct {
6✔
5239
        fnCall *FnCall
6✔
5240
        as     string
8✔
5241
}
2✔
5242

2✔
5243
func (stmt *FnDataSourceStmt) readOnly() bool {
5244
        return true
12✔
5245
}
8✔
5246

9✔
5247
func (stmt *FnDataSourceStmt) requiredPrivileges() []SQLPrivilege {
1✔
5248
        return nil
1✔
5249
}
5250

5251
func (stmt *FnDataSourceStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
5252
        return tx, nil
5253
}
5254

2✔
5255
func (stmt *FnDataSourceStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
2✔
5256
        return nil
3✔
5257
}
1✔
5258

1✔
5259
func (stmt *FnDataSourceStmt) Alias() string {
5260
        if stmt.as != "" {
1✔
UNCOV
5261
                return stmt.as
×
UNCOV
5262
        }
×
5263

5264
        switch strings.ToUpper(stmt.fnCall.fn) {
1✔
5265
        case DatabasesFnCall:
5266
                {
5267
                        return "databases"
115✔
5268
                }
115✔
5269
        case TablesFnCall:
115✔
UNCOV
5270
                {
×
UNCOV
5271
                        return "tables"
×
5272
                }
5273
        case TableFnCall:
115✔
5274
                {
115✔
5275
                        return "table"
245✔
5276
                }
130✔
5277
        case UsersFnCall:
130✔
UNCOV
5278
                {
×
UNCOV
5279
                        return "users"
×
5280
                }
5281
        case ColumnsFnCall:
5282
                {
115✔
5283
                        return "columns"
115✔
5284
                }
115✔
5285
        case IndexesFnCall:
115✔
5286
                {
115✔
5287
                        return "indexes"
5288
                }
5289
        case GrantsFnCall:
115✔
5290
                return "grants"
115✔
5291
        }
116✔
5292

1✔
5293
        // not reachable
1✔
5294
        return ""
5295
}
114✔
5296

114✔
5297
func (stmt *FnDataSourceStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (rowReader RowReader, err error) {
241✔
5298
        if stmt.fnCall == nil {
127✔
5299
                return nil, fmt.Errorf("%w: function is unspecified", ErrIllegalArguments)
128✔
5300
        }
1✔
5301

1✔
5302
        switch strings.ToUpper(stmt.fnCall.fn) {
5303
        case DatabasesFnCall:
126✔
5304
                {
127✔
5305
                        return stmt.resolveListDatabases(ctx, tx, params, scanSpecs)
1✔
5306
                }
1✔
5307
        case TablesFnCall:
5308
                {
140✔
5309
                        return stmt.resolveListTables(ctx, tx, params, scanSpecs)
15✔
5310
                }
15✔
5311
        case TableFnCall:
15✔
5312
                {
5313
                        return stmt.resolveShowTable(ctx, tx, params, scanSpecs)
5314
                }
112✔
5315
        case UsersFnCall:
5316
                {
5317
                        return stmt.resolveListUsers(ctx, tx, params, scanSpecs)
1✔
5318
                }
1✔
5319
        case ColumnsFnCall:
4✔
5320
                {
3✔
5321
                        return stmt.resolveListColumns(ctx, tx, params, scanSpecs)
3✔
5322
                }
1✔
5323
        case IndexesFnCall:
5324
                {
5325
                        return stmt.resolveListIndexes(ctx, tx, params, scanSpecs)
10✔
5326
                }
10✔
5327
        case GrantsFnCall:
10✔
5328
                {
20✔
5329
                        return stmt.resolveListGrants(ctx, tx, params, scanSpecs)
10✔
5330
                }
10✔
5331
        }
5332

10✔
5333
        return nil, fmt.Errorf("%w (%s)", ErrFunctionDoesNotExist, stmt.fnCall.fn)
10✔
5334
}
10✔
5335

10✔
5336
func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5337
        if len(stmt.fnCall.params) > 0 {
5338
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, DatabasesFnCall, len(stmt.fnCall.params))
1✔
5339
        }
1✔
5340

1✔
5341
        cols := make([]ColDescriptor, 1)
5342
        cols[0] = ColDescriptor{
21✔
5343
                Column: "name",
21✔
5344
                Type:   VarcharType,
21✔
5345
        }
21✔
5346

5347
        var dbs []string
1✔
5348

1✔
5349
        if tx.engine.multidbHandler == nil {
5✔
5350
                return nil, ErrUnspecifiedMultiDBHandler
4✔
5351
        } else {
4✔
5352
                dbs, err = tx.engine.multidbHandler.ListDatabases(ctx)
1✔
5353
                if err != nil {
5354
                        return nil, err
5355
                }
5356
        }
5357

5358
        values := make([][]ValueExp, len(dbs))
5359

5360
        for i, db := range dbs {
1✔
5361
                values[i] = []ValueExp{&Varchar{val: db}}
1✔
5362
        }
1✔
5363

5364
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
1✔
5365
}
1✔
5366

1✔
5367
func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
UNCOV
5368
        if len(stmt.fnCall.params) > 0 {
×
5369
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, TablesFnCall, len(stmt.fnCall.params))
×
5370
        }
×
5371

5372
        cols := make([]ColDescriptor, 1)
1✔
5373
        cols[0] = ColDescriptor{
1✔
5374
                Column: "name",
1✔
5375
                Type:   VarcharType,
5376
        }
24✔
5377

26✔
5378
        tables := tx.catalog.GetTables()
2✔
5379

2✔
5380
        values := make([][]ValueExp, len(tables))
5381

22✔
5382
        for i, t := range tables {
3✔
5383
                values[i] = []ValueExp{&Varchar{val: t.name}}
6✔
5384
        }
3✔
5385

3✔
5386
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5387
}
10✔
5388

5✔
5389
func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5390
        cols := []ColDescriptor{
×
5391
                {
×
5392
                        Column: "column_name",
×
5393
                        Type:   VarcharType,
×
5394
                },
7✔
5395
                {
14✔
5396
                        Column: "type_name",
7✔
5397
                        Type:   VarcharType,
7✔
5398
                },
3✔
5399
                {
6✔
5400
                        Column: "is_nullable",
3✔
5401
                        Type:   BooleanType,
3✔
5402
                },
2✔
5403
                {
4✔
5404
                        Column: "is_indexed",
2✔
5405
                        Type:   VarcharType,
2✔
5406
                },
2✔
5407
                {
2✔
5408
                        Column: "is_auto_increment",
5409
                        Type:   BooleanType,
5410
                },
5411
                {
×
5412
                        Column: "is_unique",
5413
                        Type:   BooleanType,
5414
                },
25✔
5415
        }
25✔
5416

×
5417
        tableName, _ := stmt.fnCall.params[0].reduce(tx, nil, "")
×
5418
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
5419
        if err != nil {
25✔
5420
                return nil, err
5✔
5421
        }
10✔
5422

5✔
5423
        values := make([][]ValueExp, len(table.cols))
5✔
5424

5✔
5425
        for i, c := range table.cols {
10✔
5426
                index := "NO"
5✔
5427

5✔
5428
                indexed, err := table.IsIndexed(c.Name())
×
5429
                if err != nil {
×
5430
                        return nil, err
×
5431
                }
×
5432
                if indexed {
7✔
5433
                        index = "YES"
14✔
5434
                }
7✔
5435

7✔
5436
                if table.PrimaryIndex().IncludesCol(c.ID()) {
3✔
5437
                        index = "PRIMARY KEY"
6✔
5438
                }
3✔
5439

3✔
5440
                var unique bool
3✔
5441
                for _, index := range table.GetIndexesByColID(c.ID()) {
6✔
5442
                        if index.IsUnique() && len(index.Cols()) == 1 {
3✔
5443
                                unique = true
3✔
5444
                                break
2✔
5445
                        }
4✔
5446
                }
2✔
5447

2✔
5448
                var maxLen string
5449

5450
                if c.MaxLen() > 0 && (c.Type() == VarcharType || c.Type() == BLOBType) {
×
5451
                        maxLen = fmt.Sprintf("(%d)", c.MaxLen())
5452
                }
5453

5✔
5454
                values[i] = []ValueExp{
5✔
5455
                        &Varchar{val: c.colName},
×
5456
                        &Varchar{val: c.Type() + maxLen},
×
5457
                        &Bool{val: c.IsNullable()},
5458
                        &Varchar{val: index},
5✔
5459
                        &Bool{val: c.IsAutoIncremental()},
5✔
5460
                        &Bool{val: unique},
5✔
5461
                }
5✔
5462
        }
5✔
5463

5✔
5464
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5465
}
5✔
5466

6✔
5467
func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
1✔
5468
        if len(stmt.fnCall.params) > 0 {
5✔
5469
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, UsersFnCall, len(stmt.fnCall.params))
4✔
5470
        }
4✔
UNCOV
5471

×
UNCOV
5472
        cols := []ColDescriptor{
×
5473
                {
5474
                        Column: "name",
5475
                        Type:   VarcharType,
4✔
5476
                },
4✔
5477
                {
12✔
5478
                        Column: "permission",
8✔
5479
                        Type:   VarcharType,
8✔
5480
                },
5481
        }
4✔
5482

5483
        users, err := tx.ListUsers(ctx)
5484
        if err != nil {
5✔
5485
                return nil, err
5✔
5486
        }
×
UNCOV
5487

×
5488
        values := make([][]ValueExp, len(users))
5489
        for i, user := range users {
5✔
5490
                perm := user.Permission()
5✔
5491

5✔
5492
                values[i] = []ValueExp{
5✔
5493
                        &Varchar{val: user.Username()},
5✔
5494
                        &Varchar{val: perm},
5✔
5495
                }
5✔
5496
        }
5✔
5497
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5498
}
5✔
5499

14✔
5500
func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
9✔
5501
        if len(stmt.fnCall.params) != 1 {
9✔
5502
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, ColumnsFnCall)
5503
        }
5✔
5504

5505
        cols := []ColDescriptor{
UNCOV
5506
                {
×
UNCOV
5507
                        Column: "table",
×
UNCOV
5508
                        Type:   VarcharType,
×
UNCOV
5509
                },
×
UNCOV
5510
                {
×
UNCOV
5511
                        Column: "name",
×
UNCOV
5512
                        Type:   VarcharType,
×
UNCOV
5513
                },
×
UNCOV
5514
                {
×
UNCOV
5515
                        Column: "type",
×
UNCOV
5516
                        Type:   VarcharType,
×
UNCOV
5517
                },
×
UNCOV
5518
                {
×
UNCOV
5519
                        Column: "max_length",
×
UNCOV
5520
                        Type:   IntegerType,
×
UNCOV
5521
                },
×
UNCOV
5522
                {
×
UNCOV
5523
                        Column: "nullable",
×
UNCOV
5524
                        Type:   BooleanType,
×
UNCOV
5525
                },
×
UNCOV
5526
                {
×
UNCOV
5527
                        Column: "auto_increment",
×
UNCOV
5528
                        Type:   BooleanType,
×
UNCOV
5529
                },
×
UNCOV
5530
                {
×
UNCOV
5531
                        Column: "indexed",
×
UNCOV
5532
                        Type:   BooleanType,
×
UNCOV
5533
                },
×
UNCOV
5534
                {
×
UNCOV
5535
                        Column: "primary",
×
UNCOV
5536
                        Type:   BooleanType,
×
UNCOV
5537
                },
×
UNCOV
5538
                {
×
5539
                        Column: "unique",
UNCOV
5540
                        Type:   BooleanType,
×
UNCOV
5541
                },
×
UNCOV
5542
        }
×
UNCOV
5543

×
UNCOV
5544
        val, err := stmt.fnCall.params[0].substitute(params)
×
UNCOV
5545
        if err != nil {
×
5546
                return nil, err
×
5547
        }
×
UNCOV
5548

×
UNCOV
5549
        tableName, err := val.reduce(tx, nil, "")
×
UNCOV
5550
        if err != nil {
×
5551
                return nil, err
×
5552
        }
UNCOV
5553

×
UNCOV
5554
        if tableName.Type() != VarcharType {
×
5555
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5556
        }
UNCOV
5557

×
UNCOV
5558
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
×
UNCOV
5559
        if err != nil {
×
5560
                return nil, err
×
5561
        }
×
5562

5563
        values := make([][]ValueExp, len(table.cols))
5564

UNCOV
5565
        for i, c := range table.cols {
×
UNCOV
5566
                indexed, err := table.IsIndexed(c.Name())
×
UNCOV
5567
                if err != nil {
×
5568
                        return nil, err
×
5569
                }
×
5570

UNCOV
5571
                var unique bool
×
UNCOV
5572
                for _, index := range table.indexesByColID[c.id] {
×
UNCOV
5573
                        if index.IsUnique() && len(index.Cols()) == 1 {
×
UNCOV
5574
                                unique = true
×
UNCOV
5575
                                break
×
UNCOV
5576
                        }
×
UNCOV
5577
                }
×
UNCOV
5578

×
5579
                values[i] = []ValueExp{
5580
                        &Varchar{val: table.name},
UNCOV
5581
                        &Varchar{val: c.colName},
×
5582
                        &Varchar{val: c.colType},
5583
                        &Integer{val: int64(c.MaxLen())},
5584
                        &Bool{val: c.IsNullable()},
7✔
5585
                        &Bool{val: c.autoIncrement},
7✔
UNCOV
5586
                        &Bool{val: indexed},
×
UNCOV
5587
                        &Bool{val: table.PrimaryIndex().IncludesCol(c.ID())},
×
5588
                        &Bool{val: unique},
5589
                }
7✔
5590
        }
7✔
5591

7✔
5592
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
7✔
5593
}
7✔
5594

7✔
5595
func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
7✔
5596
        if len(stmt.fnCall.params) != 1 {
7✔
5597
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, IndexesFnCall)
7✔
5598
        }
7✔
5599

7✔
5600
        cols := []ColDescriptor{
7✔
5601
                {
7✔
UNCOV
5602
                        Column: "table",
×
UNCOV
5603
                        Type:   VarcharType,
×
5604
                },
5605
                {
7✔
5606
                        Column: "name",
23✔
5607
                        Type:   VarcharType,
16✔
5608
                },
16✔
5609
                {
16✔
5610
                        Column: "unique",
16✔
5611
                        Type:   BooleanType,
16✔
5612
                },
16✔
5613
                {
16✔
5614
                        Column: "primary",
7✔
5615
                        Type:   BooleanType,
5616
                },
5617
        }
3✔
5618

3✔
UNCOV
5619
        val, err := stmt.fnCall.params[0].substitute(params)
×
UNCOV
5620
        if err != nil {
×
5621
                return nil, err
5622
        }
3✔
5623

3✔
5624
        tableName, err := val.reduce(tx, nil, "")
3✔
5625
        if err != nil {
3✔
5626
                return nil, err
3✔
5627
        }
3✔
5628

3✔
5629
        if tableName.Type() != VarcharType {
3✔
5630
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
3✔
5631
        }
3✔
5632

3✔
5633
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5634
        if err != nil {
3✔
5635
                return nil, err
3✔
5636
        }
3✔
5637

3✔
5638
        values := make([][]ValueExp, len(table.indexes))
3✔
5639

3✔
5640
        for i, index := range table.indexes {
3✔
5641
                values[i] = []ValueExp{
3✔
5642
                        &Varchar{val: table.name},
3✔
5643
                        &Varchar{val: index.Name()},
3✔
5644
                        &Bool{val: index.unique},
3✔
5645
                        &Bool{val: index.IsPrimary()},
3✔
5646
                }
3✔
5647
        }
3✔
5648

3✔
5649
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5650
}
3✔
5651

3✔
5652
func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5653
        if len(stmt.fnCall.params) > 1 {
3✔
5654
                return nil, fmt.Errorf("%w: function '%s' expect at most one parameter of type %s", ErrIllegalArguments, GrantsFnCall, VarcharType)
3✔
5655
        }
3✔
5656

3✔
5657
        var username string
3✔
5658
        if len(stmt.fnCall.params) == 1 {
3✔
5659
                val, err := stmt.fnCall.params[0].substitute(params)
3✔
5660
                if err != nil {
3✔
5661
                        return nil, err
3✔
5662
                }
3✔
UNCOV
5663

×
UNCOV
5664
                userVal, err := val.reduce(tx, nil, "")
×
5665
                if err != nil {
5666
                        return nil, err
3✔
5667
                }
3✔
UNCOV
5668

×
UNCOV
5669
                if userVal.Type() != VarcharType {
×
5670
                        return nil, fmt.Errorf("%w: expected '%s' for username but type '%s' given instead", ErrIllegalArguments, VarcharType, userVal.Type())
5671
                }
3✔
UNCOV
5672
                username, _ = userVal.RawValue().(string)
×
UNCOV
5673
        }
×
5674

5675
        cols := []ColDescriptor{
3✔
5676
                {
3✔
UNCOV
5677
                        Column: "user",
×
UNCOV
5678
                        Type:   VarcharType,
×
5679
                },
5680
                {
3✔
5681
                        Column: "privilege",
3✔
5682
                        Type:   VarcharType,
11✔
5683
                },
8✔
5684
        }
8✔
UNCOV
5685

×
UNCOV
5686
        var err error
×
5687
        var users []User
5688

8✔
5689
        if tx.engine.multidbHandler == nil {
16✔
5690
                return nil, ErrUnspecifiedMultiDBHandler
11✔
5691
        } else {
3✔
5692
                users, err = tx.engine.multidbHandler.ListUsers(ctx)
3✔
5693
                if err != nil {
5694
                        return nil, err
5695
                }
5696
        }
8✔
5697

8✔
5698
        values := make([][]ValueExp, 0, len(users))
8✔
5699

8✔
5700
        for _, user := range users {
8✔
5701
                if username == "" || user.Username() == username {
8✔
5702
                        for _, p := range user.SQLPrivileges() {
8✔
5703
                                values = append(values, []ValueExp{
8✔
5704
                                        &Varchar{val: user.Username()},
8✔
5705
                                        &Varchar{val: string(p)},
8✔
5706
                                })
8✔
5707
                        }
5708
                }
5709
        }
3✔
5710

5711
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5712
}
3✔
5713

3✔
UNCOV
5714
// DropTableStmt represents a statement to delete a table.
×
UNCOV
5715
type DropTableStmt struct {
×
5716
        table string
5717
}
3✔
5718

3✔
5719
func NewDropTableStmt(table string) *DropTableStmt {
3✔
5720
        return &DropTableStmt{table: table}
3✔
5721
}
3✔
5722

3✔
5723
func (stmt *DropTableStmt) readOnly() bool {
3✔
5724
        return false
3✔
5725
}
3✔
5726

3✔
5727
func (stmt *DropTableStmt) requiredPrivileges() []SQLPrivilege {
3✔
5728
        return []SQLPrivilege{SQLPrivilegeDrop}
3✔
5729
}
3✔
5730

3✔
5731
func (stmt *DropTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
5732
        return nil
3✔
5733
}
3✔
5734

3✔
5735
/*
3✔
5736
Exec executes the delete table statement.
3✔
5737
It the table exists, if not it does nothing.
3✔
UNCOV
5738
If the table exists, it deletes all the indexes and the table itself.
×
UNCOV
5739
Note that this is a soft delete of the index and table key,
×
5740
the data is not deleted, but the metadata is updated.
5741
*/
3✔
5742
func (stmt *DropTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
UNCOV
5743
        if !tx.catalog.ExistTable(stmt.table) {
×
UNCOV
5744
                return nil, ErrTableDoesNotExist
×
5745
        }
5746

3✔
UNCOV
5747
        table, err := tx.catalog.GetTableByName(stmt.table)
×
UNCOV
5748
        if err != nil {
×
5749
                return nil, err
5750
        }
3✔
5751

3✔
UNCOV
5752
        // delete table
×
UNCOV
5753
        mappedKey := MapKey(
×
5754
                tx.sqlPrefix(),
5755
                catalogTablePrefix,
3✔
5756
                EncodeID(DatabaseID),
3✔
5757
                EncodeID(table.id),
10✔
5758
        )
7✔
5759
        err = tx.delete(ctx, mappedKey)
7✔
5760
        if err != nil {
7✔
5761
                return nil, err
7✔
5762
        }
7✔
5763

7✔
5764
        // delete columns
7✔
5765
        cols := table.ColumnsByID()
5766
        for _, col := range cols {
3✔
5767
                mappedKey := MapKey(
5768
                        tx.sqlPrefix(),
5769
                        catalogColumnPrefix,
2✔
5770
                        EncodeID(DatabaseID),
2✔
UNCOV
5771
                        EncodeID(col.table.id),
×
UNCOV
5772
                        EncodeID(col.id),
×
5773
                        []byte(col.colType),
5774
                )
2✔
5775
                err = tx.delete(ctx, mappedKey)
3✔
5776
                if err != nil {
1✔
5777
                        return nil, err
1✔
5778
                }
×
UNCOV
5779
        }
×
5780

5781
        // delete checks
1✔
5782
        for name := range table.checkConstraints {
1✔
5783
                key := MapKey(
×
5784
                        tx.sqlPrefix(),
×
5785
                        catalogCheckPrefix,
5786
                        EncodeID(DatabaseID),
1✔
5787
                        EncodeID(table.id),
×
5788
                        []byte(name),
×
5789
                )
1✔
5790

5791
                if err := tx.delete(ctx, key); err != nil {
5792
                        return nil, err
2✔
5793
                }
2✔
5794
        }
2✔
5795

2✔
5796
        // delete indexes
2✔
5797
        for _, index := range table.indexes {
2✔
5798
                mappedKey := MapKey(
2✔
5799
                        tx.sqlPrefix(),
2✔
5800
                        catalogIndexPrefix,
2✔
5801
                        EncodeID(DatabaseID),
2✔
5802
                        EncodeID(table.id),
2✔
5803
                        EncodeID(index.id),
2✔
5804
                )
2✔
5805
                err = tx.delete(ctx, mappedKey)
2✔
5806
                if err != nil {
2✔
5807
                        return nil, err
×
5808
                }
2✔
5809

2✔
5810
                indexKey := MapKey(
2✔
UNCOV
5811
                        tx.sqlPrefix(),
×
UNCOV
5812
                        MappedPrefix,
×
5813
                        EncodeID(table.id),
5814
                        EncodeID(index.id),
5815
                )
2✔
5816
                err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
2✔
5817
                        return sqlTx.engine.store.DeleteIndex(indexKey)
4✔
5818
                })
4✔
5819
                if err != nil {
6✔
5820
                        return nil, err
4✔
5821
                }
4✔
5822
        }
4✔
5823

4✔
5824
        err = tx.catalog.deleteTable(table)
4✔
5825
        if err != nil {
5826
                return nil, err
5827
        }
5828

2✔
5829
        tx.mutatedCatalog = true
5830

5831
        return tx, nil
5832
}
5833

5834
// DropIndexStmt represents a statement to delete a table.
5835
type DropIndexStmt struct {
5836
        table string
6✔
5837
        cols  []string
6✔
5838
}
6✔
5839

5840
func NewDropIndexStmt(table string, cols []string) *DropIndexStmt {
1✔
5841
        return &DropIndexStmt{table: table, cols: cols}
1✔
5842
}
1✔
5843

5844
func (stmt *DropIndexStmt) readOnly() bool {
1✔
5845
        return false
1✔
5846
}
1✔
5847

5848
func (stmt *DropIndexStmt) requiredPrivileges() []SQLPrivilege {
1✔
5849
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5850
}
1✔
5851

5852
func (stmt *DropIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
5853
        return nil
5854
}
5855

5856
/*
5857
Exec executes the delete index statement.
5858
If the index exists, it deletes it. Note that this is a soft delete of the index
5859
the data is not deleted, but the metadata is updated.
7✔
5860
*/
8✔
5861
func (stmt *DropIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
5862
        if !tx.catalog.ExistTable(stmt.table) {
1✔
5863
                return nil, ErrTableDoesNotExist
5864
        }
6✔
5865

6✔
UNCOV
5866
        table, err := tx.catalog.GetTableByName(stmt.table)
×
UNCOV
5867
        if err != nil {
×
5868
                return nil, err
5869
        }
5870

6✔
5871
        cols := make([]*Column, len(stmt.cols))
6✔
5872

6✔
5873
        for i, colName := range stmt.cols {
6✔
5874
                col, err := table.GetColumnByName(colName)
6✔
5875
                if err != nil {
6✔
5876
                        return nil, err
6✔
5877
                }
6✔
UNCOV
5878

×
UNCOV
5879
                cols[i] = col
×
5880
        }
5881

5882
        index, err := table.GetIndexByName(indexName(table.name, cols))
6✔
5883
        if err != nil {
26✔
5884
                return nil, err
20✔
5885
        }
20✔
5886

20✔
5887
        // delete index
20✔
5888
        mappedKey := MapKey(
20✔
5889
                tx.sqlPrefix(),
20✔
5890
                catalogIndexPrefix,
20✔
5891
                EncodeID(DatabaseID),
20✔
5892
                EncodeID(table.id),
20✔
5893
                EncodeID(index.id),
20✔
UNCOV
5894
        )
×
UNCOV
5895
        err = tx.delete(ctx, mappedKey)
×
5896
        if err != nil {
5897
                return nil, err
5898
        }
5899

6✔
UNCOV
5900
        indexKey := MapKey(
×
UNCOV
5901
                tx.sqlPrefix(),
×
UNCOV
5902
                MappedPrefix,
×
UNCOV
5903
                EncodeID(table.id),
×
UNCOV
5904
                EncodeID(index.id),
×
UNCOV
5905
        )
×
UNCOV
5906

×
UNCOV
5907
        err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
×
UNCOV
5908
                return sqlTx.engine.store.DeleteIndex(indexKey)
×
UNCOV
5909
        })
×
UNCOV
5910
        if err != nil {
×
5911
                return nil, err
5912
        }
5913

5914
        err = table.deleteIndex(index)
13✔
5915
        if err != nil {
7✔
5916
                return nil, err
7✔
5917
        }
7✔
5918

7✔
5919
        tx.mutatedCatalog = true
7✔
5920

7✔
5921
        return tx, nil
7✔
5922
}
7✔
5923

7✔
UNCOV
5924
// SQLPrivilege names an SQL-level privilege that can be granted to or revoked from
// a user.
type SQLPrivilege string

// SQL privileges known to the engine; each constant's value is its upper-case keyword.
const (
	SQLPrivilegeSelect SQLPrivilege = "SELECT" // read rows
	SQLPrivilegeCreate SQLPrivilege = "CREATE" // create schema objects
	SQLPrivilegeInsert SQLPrivilege = "INSERT" // insert rows
	SQLPrivilegeUpdate SQLPrivilege = "UPDATE" // update rows
	SQLPrivilegeDelete SQLPrivilege = "DELETE" // delete rows
	SQLPrivilegeDrop   SQLPrivilege = "DROP"   // drop schema objects
	SQLPrivilegeAlter  SQLPrivilege = "ALTER"  // alter schema objects
)

// allPrivileges lists every SQLPrivilege in declaration order. It is shared package
// state; code should hand out copies rather than this slice itself.
var allPrivileges = []SQLPrivilege{
	SQLPrivilegeSelect, SQLPrivilegeCreate, SQLPrivilegeInsert, SQLPrivilegeUpdate,
	SQLPrivilegeDelete, SQLPrivilegeDrop, SQLPrivilegeAlter,
}
×
5945

5946
func DefaultSQLPrivilegesForPermission(p Permission) []SQLPrivilege {
6✔
5947
        switch p {
6✔
5948
        case PermissionSysAdmin, PermissionAdmin, PermissionReadWrite:
6✔
5949
                return allPrivileges
5950
        case PermissionReadOnly:
5951
                return []SQLPrivilege{SQLPrivilegeSelect}
5952
        }
5953
        return nil
5954
}
5955

5956
type AlterPrivilegesStmt struct {
5957
        database   string
4✔
5958
        user       string
4✔
5959
        privileges []SQLPrivilege
4✔
5960
        isGrant    bool
5961
}
1✔
5962

1✔
5963
func (stmt *AlterPrivilegesStmt) readOnly() bool {
1✔
5964
        return false
5965
}
1✔
5966

1✔
5967
func (stmt *AlterPrivilegesStmt) requiredPrivileges() []SQLPrivilege {
1✔
5968
        return nil
5969
}
1✔
5970

1✔
5971
func (stmt *AlterPrivilegesStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
5972
        if tx.IsExplicitCloseRequired() {
5973
                return nil, fmt.Errorf("%w: user privileges modification can not be done within a transaction", ErrNonTransactionalStmt)
5974
        }
5975

5976
        if tx.engine.multidbHandler == nil {
5977
                return nil, ErrUnspecifiedMultiDBHandler
5978
        }
6✔
5979

7✔
5980
        var err error
1✔
5981
        if stmt.isGrant {
1✔
5982
                err = tx.engine.multidbHandler.GrantSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
5983
        } else {
5✔
5984
                err = tx.engine.multidbHandler.RevokeSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
5✔
UNCOV
5985
        }
×
UNCOV
5986
        return nil, err
×
5987
}
5988

5✔
5989
func (stmt *AlterPrivilegesStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
5✔
5990
        return nil
10✔
5991
}
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc