• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 24236186926

10 Apr 2026 09:25AM UTC coverage: 89.169% (-0.09%) from 89.257%
24236186926

push

gh-ci

SimoneLazzaris
fix workflows

38207 of 42848 relevant lines covered (89.17%)

151869.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.18
/embedded/sql/stmt.go
1
/*
2
Copyright 2025 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "context"
22
        "encoding/binary"
23
        "encoding/hex"
24
        "errors"
25
        "fmt"
26
        "math"
27
        "regexp"
28
        "strconv"
29
        "strings"
30
        "time"
31

32
        "github.com/codenotary/immudb/embedded/store"
33
        "github.com/google/uuid"
34
)
35

36
// Prefixes of the keys under which the SQL layer persists its catalog and rows.
// Catalog keys in this file are built as MapKey(tx.sqlPrefix(), <prefix>, ...);
// the trailing comments describe the remaining key components and the stored
// value. {1} stands for the encoded DatabaseID.
const (
	catalogPrefix          = "CTL."
	catalogTablePrefix     = "CTL.TABLE."     // (key=CTL.TABLE.{1}{tableID}, value={tableNAME})
	catalogColumnPrefix    = "CTL.COLUMN."    // (key=CTL.COLUMN.{1}{tableID}{colID}{colTYPE}, value={(auto_incremental | nullable){maxLen: 4-byte big-endian}{colNAME}})
	catalogIndexPrefix     = "CTL.INDEX."     // (key=CTL.INDEX.{1}{tableID}{indexID}, value={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)})
	catalogCheckPrefix     = "CTL.CHECK."     // (key=CTL.CHECK.{1}{tableID}{checkID}, value={nameLen-1}{name}{expText}{0x00})
	catalogPrivilegePrefix = "CTL.PRIVILEGE." // NOTE(review): key/value layout undocumented here (the previous comment duplicated catalogColumnPrefix's); TODO document from the privilege writer

	RowPrefix    = "R." // (key=R.{1}{tableID}{0}({null}({pkVal}{padding}{pkValLen})?)+, value={count (colID valLen val)+})
	MappedPrefix = "M." // (key=M.{tableID}{indexID}({null}({val}{padding}{valLen})?)*({pkVal}{padding}{pkValLen})+, value={count (colID valLen val)+})
)

const (
	// DatabaseID is encoded as the {1} component of catalog keys.
	// It is deprecated but left to maintain backwards compatibility.
	DatabaseID = uint32(1)

	PKIndexID = uint32(0)
)

// Bit flags stored in the first byte of a column catalog entry (see
// persistColumn). persistColumn sets nullableFlag when the column is declared
// NOT NULL, so despite its name the bit marks a non-nullable column.
const (
	nullableFlag      byte = 1 << iota
	autoIncrementFlag byte = 1 << iota
)
57

58
// Reserved pseudo-column names.
const (
	revCol        = "_rev"
	txMetadataCol = "_tx_metadata"
	diffActionCol = "_diff_action"
)

// reservedColumns is the set of names recognised by isReservedCol.
var reservedColumns = map[string]struct{}{
	revCol:        {},
	txMetadataCol: {},
	diffActionCol: {},
}

// isReservedCol reports whether col is exactly (case-sensitively) one of the
// reserved pseudo-column names.
func isReservedCol(col string) bool {
	if _, reserved := reservedColumns[col]; reserved {
		return true
	}
	return false
}
74

75
// SQLValueType is the textual name of an SQL column/value type.
type SQLValueType = string

const (
	IntegerType   SQLValueType = "INTEGER"
	BooleanType   SQLValueType = "BOOLEAN"
	VarcharType   SQLValueType = "VARCHAR"
	UUIDType      SQLValueType = "UUID"
	BLOBType      SQLValueType = "BLOB"
	Float64Type   SQLValueType = "FLOAT"
	TimestampType SQLValueType = "TIMESTAMP"
	AnyType       SQLValueType = "ANY"
	JSONType      SQLValueType = "JSON"
)

// IsNumericType reports whether t is one of the numeric types, INTEGER or
// FLOAT. The comparison is exact, so lower-case spellings are not numeric.
func IsNumericType(t SQLValueType) bool {
	switch t {
	case IntegerType, Float64Type:
		return true
	default:
		return false
	}
}
92

93
// Permission is the textual name of a user permission level.
type Permission = string

const (
	PermissionReadOnly  Permission = "READ"
	PermissionReadWrite Permission = "READWRITE"
	PermissionAdmin     Permission = "ADMIN"
	PermissionSysAdmin  Permission = "SYSADMIN"
)

// PermissionFromCode maps a numeric permission code to its Permission name:
// 1 -> READ, 2 -> READWRITE and 254 -> ADMIN. Every other code, including
// values that are not recognised at all, maps to SYSADMIN.
//
// NOTE(review): falling back to the most privileged level for unrecognised
// codes fails open; confirm callers never pass codes outside the expected set.
func PermissionFromCode(code uint32) Permission {
	switch code {
	case 1:
		return PermissionReadOnly
	case 2:
		return PermissionReadWrite
	case 254:
		return PermissionAdmin
	default:
		return PermissionSysAdmin
	}
}
119

120
// AggregateFn names an SQL aggregate function.
type AggregateFn = string

const (
	COUNT AggregateFn = "COUNT"
	SUM   AggregateFn = "SUM"
	MAX   AggregateFn = "MAX"
	MIN   AggregateFn = "MIN"
	AVG   AggregateFn = "AVG"
)

// CmpOperator identifies a binary comparison operator.
type CmpOperator = int

const (
	EQ CmpOperator = iota
	NE
	LT
	LE
	GT
	GE
)

// CmpOperatorToString returns the SQL symbol of op ("=", "!=", "<", "<=",
// ">" or ">="), or the empty string when op is not a known comparison operator.
func CmpOperatorToString(op CmpOperator) string {
	symbols := [...]string{
		EQ: "=",
		NE: "!=",
		LT: "<",
		LE: "<=",
		GT: ">",
		GE: ">=",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
158

159
// LogicOperator identifies a boolean connective.
type LogicOperator = int

const (
	And LogicOperator = iota
	Or
)

// LogicOperatorToString returns "AND" for And. Every other value, Or included,
// is rendered as "OR".
func LogicOperatorToString(op LogicOperator) string {
	switch op {
	case And:
		return "AND"
	default:
		return "OR"
	}
}
172

173
// NumOperator identifies a binary arithmetic operator.
type NumOperator = int

const (
	ADDOP NumOperator = iota
	SUBSOP
	DIVOP
	MULTOP
	MODOP
)

// NumOperatorString returns the SQL symbol of op ("+", "-", "/", "*" or "%"),
// or the empty string when op is not a known arithmetic operator.
func NumOperatorString(op NumOperator) string {
	symbols := [...]string{
		ADDOP:  "+",
		SUBSOP: "-",
		DIVOP:  "/",
		MULTOP: "*",
		MODOP:  "%",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
198

199
// JoinType identifies the kind of a JOIN clause.
type JoinType = int

const (
	InnerJoin JoinType = iota
	LeftJoin
	RightJoin
)

// SQLStmt is the interface implemented by the statement types in this file.
type SQLStmt interface {
	// readOnly reports whether the statement is treated as read-only.
	// Transaction-control statements (BEGIN, COMMIT, ROLLBACK) and USE report true.
	readOnly() bool
	// requiredPrivileges lists the privileges associated with the statement,
	// or nil when none apply. NOTE(review): enforcement presumably happens in
	// the caller; confirm.
	requiredPrivileges() []SQLPrivilege
	// execAt executes the statement within tx and returns the transaction to
	// continue with. Several implementations return a nil transaction even on
	// success (e.g. COMMIT, ROLLBACK, CREATE DATABASE, CREATE USER).
	execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error)
	// inferParameters presumably records in params the types inferred for the
	// statement's named parameters; the implementations visible here are no-ops.
	inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error
}
213

214
type BeginTransactionStmt struct {
215
}
216

217
func (stmt *BeginTransactionStmt) readOnly() bool {
9✔
218
        return true
9✔
219
}
9✔
220

221
func (stmt *BeginTransactionStmt) requiredPrivileges() []SQLPrivilege {
9✔
222
        return nil
9✔
223
}
9✔
224

225
func (stmt *BeginTransactionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
226
        return nil
3✔
227
}
3✔
228

229
func (stmt *BeginTransactionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
61✔
230
        if tx.IsExplicitCloseRequired() {
62✔
231
                return nil, ErrNestedTxNotSupported
1✔
232
        }
1✔
233

234
        err := tx.RequireExplicitClose()
60✔
235
        if err == nil {
119✔
236
                // current tx can be reused as no changes were already made
59✔
237
                return tx, nil
59✔
238
        }
59✔
239

240
        // commit current transaction and start a fresh one
241

242
        err = tx.Commit(ctx)
1✔
243
        if err != nil {
1✔
244
                return nil, err
×
245
        }
×
246

247
        return tx.engine.NewTx(ctx, tx.opts.WithExplicitClose(true))
1✔
248
}
249

250
type CommitStmt struct {
251
}
252

253
func (stmt *CommitStmt) readOnly() bool {
113✔
254
        return true
113✔
255
}
113✔
256

257
func (stmt *CommitStmt) requiredPrivileges() []SQLPrivilege {
113✔
258
        return nil
113✔
259
}
113✔
260

261
func (stmt *CommitStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
262
        return nil
1✔
263
}
1✔
264

265
func (stmt *CommitStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
159✔
266
        if !tx.IsExplicitCloseRequired() {
160✔
267
                return nil, ErrNoOngoingTx
1✔
268
        }
1✔
269

270
        return nil, tx.Commit(ctx)
158✔
271
}
272

273
type RollbackStmt struct {
274
}
275

276
func (stmt *RollbackStmt) readOnly() bool {
1✔
277
        return true
1✔
278
}
1✔
279

280
func (stmt *RollbackStmt) requiredPrivileges() []SQLPrivilege {
1✔
281
        return nil
1✔
282
}
1✔
283

284
func (stmt *RollbackStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
285
        return nil
1✔
286
}
1✔
287

288
func (stmt *RollbackStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
289
        if !tx.IsExplicitCloseRequired() {
5✔
290
                return nil, ErrNoOngoingTx
1✔
291
        }
1✔
292

293
        return nil, tx.Cancel()
3✔
294
}
295

296
type CreateDatabaseStmt struct {
297
        DB          string
298
        ifNotExists bool
299
}
300

301
func (stmt *CreateDatabaseStmt) readOnly() bool {
14✔
302
        return false
14✔
303
}
14✔
304

305
func (stmt *CreateDatabaseStmt) requiredPrivileges() []SQLPrivilege {
14✔
306
        return []SQLPrivilege{SQLPrivilegeCreate}
14✔
307
}
14✔
308

309
func (stmt *CreateDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
310
        return nil
4✔
311
}
4✔
312

313
func (stmt *CreateDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
314
        if tx.IsExplicitCloseRequired() {
16✔
315
                return nil, fmt.Errorf("%w: database creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
316
        }
1✔
317

318
        if tx.engine.multidbHandler == nil {
16✔
319
                return nil, ErrUnspecifiedMultiDBHandler
2✔
320
        }
2✔
321

322
        return nil, tx.engine.multidbHandler.CreateDatabase(ctx, stmt.DB, stmt.ifNotExists)
12✔
323
}
324

325
type UseDatabaseStmt struct {
326
        DB string
327
}
328

329
func (stmt *UseDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
330
        return nil
1✔
331
}
1✔
332

333
func (stmt *UseDatabaseStmt) readOnly() bool {
9✔
334
        return true
9✔
335
}
9✔
336

337
func (stmt *UseDatabaseStmt) requiredPrivileges() []SQLPrivilege {
9✔
338
        return nil
9✔
339
}
9✔
340

341
func (stmt *UseDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
342
        if tx.IsExplicitCloseRequired() {
10✔
343
                return nil, fmt.Errorf("%w: database selection can NOT be executed within a transaction block", ErrNonTransactionalStmt)
1✔
344
        }
1✔
345

346
        if tx.engine.multidbHandler == nil {
9✔
347
                return nil, ErrUnspecifiedMultiDBHandler
1✔
348
        }
1✔
349

350
        return tx, tx.engine.multidbHandler.UseDatabase(ctx, stmt.DB)
7✔
351
}
352

353
type UseSnapshotStmt struct {
354
        period period
355
}
356

357
func (stmt *UseSnapshotStmt) readOnly() bool {
1✔
358
        return true
1✔
359
}
1✔
360

361
func (stmt *UseSnapshotStmt) requiredPrivileges() []SQLPrivilege {
1✔
362
        return nil
1✔
363
}
1✔
364

365
func (stmt *UseSnapshotStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
366
        return nil
1✔
367
}
1✔
368

369
func (stmt *UseSnapshotStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
370
        return nil, ErrNoSupported
1✔
371
}
1✔
372

373
type CreateUserStmt struct {
374
        username   string
375
        password   string
376
        permission Permission
377
}
378

379
func (stmt *CreateUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
380
        return nil
1✔
381
}
1✔
382

383
func (stmt *CreateUserStmt) readOnly() bool {
5✔
384
        return false
5✔
385
}
5✔
386

387
func (stmt *CreateUserStmt) requiredPrivileges() []SQLPrivilege {
5✔
388
        return []SQLPrivilege{SQLPrivilegeCreate}
5✔
389
}
5✔
390

391
func (stmt *CreateUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
392
        if tx.IsExplicitCloseRequired() {
7✔
393
                return nil, fmt.Errorf("%w: user creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
394
        }
1✔
395

396
        if tx.engine.multidbHandler == nil {
6✔
397
                return nil, ErrUnspecifiedMultiDBHandler
1✔
398
        }
1✔
399

400
        return nil, tx.engine.multidbHandler.CreateUser(ctx, stmt.username, stmt.password, stmt.permission)
4✔
401
}
402

403
type AlterUserStmt struct {
404
        username   string
405
        password   string
406
        permission Permission
407
}
408

409
func (stmt *AlterUserStmt) readOnly() bool {
4✔
410
        return false
4✔
411
}
4✔
412

413
func (stmt *AlterUserStmt) requiredPrivileges() []SQLPrivilege {
4✔
414
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
415
}
4✔
416

417
func (stmt *AlterUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
418
        return nil
1✔
419
}
1✔
420

421
func (stmt *AlterUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
5✔
422
        if tx.IsExplicitCloseRequired() {
6✔
423
                return nil, fmt.Errorf("%w: user modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
424
        }
1✔
425

426
        if tx.engine.multidbHandler == nil {
5✔
427
                return nil, ErrUnspecifiedMultiDBHandler
1✔
428
        }
1✔
429

430
        return nil, tx.engine.multidbHandler.AlterUser(ctx, stmt.username, stmt.password, stmt.permission)
3✔
431
}
432

433
type DropUserStmt struct {
434
        username string
435
}
436

437
func (stmt *DropUserStmt) readOnly() bool {
3✔
438
        return false
3✔
439
}
3✔
440

441
func (stmt *DropUserStmt) requiredPrivileges() []SQLPrivilege {
3✔
442
        return []SQLPrivilege{SQLPrivilegeDrop}
3✔
443
}
3✔
444

445
func (stmt *DropUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
446
        return nil
1✔
447
}
1✔
448

449
func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
450
        if tx.IsExplicitCloseRequired() {
4✔
451
                return nil, fmt.Errorf("%w: user deletion can not be done within a transaction", ErrNonTransactionalStmt)
1✔
452
        }
1✔
453

454
        if tx.engine.multidbHandler == nil {
3✔
455
                return nil, ErrUnspecifiedMultiDBHandler
1✔
456
        }
1✔
457

458
        return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
1✔
459
}
460

461
type TableElem interface{}
462

463
type CreateTableStmt struct {
464
        table       string
465
        ifNotExists bool
466
        colsSpec    []*ColSpec
467
        checks      []CheckConstraint
468
        pkColNames  PrimaryKeyConstraint
469
}
470

471
func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
38✔
472
        return &CreateTableStmt{table: table, ifNotExists: ifNotExists, colsSpec: colsSpec, pkColNames: pkColNames}
38✔
473
}
38✔
474

475
func (stmt *CreateTableStmt) readOnly() bool {
66✔
476
        return false
66✔
477
}
66✔
478

479
func (stmt *CreateTableStmt) requiredPrivileges() []SQLPrivilege {
64✔
480
        return []SQLPrivilege{SQLPrivilegeCreate}
64✔
481
}
64✔
482

483
func (stmt *CreateTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
484
        return nil
4✔
485
}
4✔
486

487
func zeroRow(tableName string, cols []*ColSpec) *Row {
246✔
488
        r := Row{
246✔
489
                ValuesByPosition: make([]TypedValue, len(cols)),
246✔
490
                ValuesBySelector: make(map[string]TypedValue, len(cols)),
246✔
491
        }
246✔
492

246✔
493
        for i, col := range cols {
1,031✔
494
                v := zeroForType(col.colType)
785✔
495

785✔
496
                r.ValuesByPosition[i] = v
785✔
497
                r.ValuesBySelector[EncodeSelector("", tableName, col.colName)] = v
785✔
498
        }
785✔
499
        return &r
246✔
500
}
501

502
// execAt creates the table described by stmt within tx and persists its
// catalog entries: columns, check constraints, the primary index (through
// CreateIndexStmt) and the table-name entry.
//
// With IF NOT EXISTS, an already existing table makes it a no-op returning tx.
// Errors:
//   - primary key errors from validatePrimaryKey;
//   - ErrInvalidCheckConstraint when a check expression is not boolean;
//   - ErrLimitedAutoIncrement when an AUTO_INCREMENT column is not the sole
//     primary-key column;
//   - any error from the catalog or the underlying tx.
func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
	if err := stmt.validatePrimaryKey(); err != nil {
		return nil, err
	}

	if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
		return tx, nil
	}

	// column IDs are assigned 1..N in declaration order
	colSpecs := make(map[uint32]*ColSpec, len(stmt.colsSpec))
	for i, cs := range stmt.colsSpec {
		colSpecs[uint32(i)+1] = cs
	}

	// type-check each check expression by reducing it against a row of zero
	// values; the result must be boolean
	row := zeroRow(stmt.table, stmt.colsSpec)
	for _, check := range stmt.checks {
		value, err := check.exp.reduce(tx, row, stmt.table)
		if err != nil {
			return nil, err
		}

		if value.Type() != BooleanType {
			return nil, ErrInvalidCheckConstraint
		}
	}

	// unnamed checks are named "<table>_check<k>", k counting unnamed checks
	// only; IDs are the positions in stmt.checks.
	// NOTE(review): two checks resolving to the same name overwrite each other
	// in this map without an error; confirm duplicates are rejected elsewhere.
	nextUnnamedCheck := 0
	checks := make(map[string]CheckConstraint)
	for id, check := range stmt.checks {
		name := fmt.Sprintf("%s_check%d", stmt.table, nextUnnamedCheck+1)
		if check.name != "" {
			name = check.name
		} else {
			nextUnnamedCheck++
		}
		check.id = uint32(id)
		check.name = name
		checks[name] = check
	}

	table, err := tx.catalog.newTable(stmt.table, colSpecs, checks, uint32(len(colSpecs)))
	if err != nil {
		return nil, err
	}

	// the primary key is created as a unique index on the primary-key columns
	createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.primaryKeyCols()}
	_, err = createIndexStmt.execAt(ctx, tx, params)
	if err != nil {
		return nil, err
	}

	for _, col := range table.cols {
		// AUTO_INCREMENT is only allowed on a single-column primary key
		if col.autoIncrement {
			if len(table.primaryIndex.cols) > 1 || col.id != table.primaryIndex.cols[0].id {
				return nil, ErrLimitedAutoIncrement
			}
		}

		err := persistColumn(tx, col)
		if err != nil {
			return nil, err
		}
	}

	// NOTE(review): checks are persisted in map iteration order, which is not
	// deterministic
	for _, check := range checks {
		if err := persistCheck(tx, table, &check); err != nil {
			return nil, err
		}
	}

	mappedKey := MapKey(tx.sqlPrefix(), catalogTablePrefix, EncodeID(DatabaseID), EncodeID(table.id))

	err = tx.set(mappedKey, nil, []byte(table.name))
	if err != nil {
		return nil, err
	}

	tx.mutatedCatalog = true

	return tx, nil
}
583

584
func (stmt *CreateTableStmt) validatePrimaryKey() error {
237✔
585
        n := 0
237✔
586
        for _, spec := range stmt.colsSpec {
969✔
587
                if spec.primaryKey {
737✔
588
                        n++
5✔
589
                }
5✔
590
        }
591

592
        if len(stmt.pkColNames) > 0 {
470✔
593
                n++
233✔
594
        }
233✔
595

596
        switch n {
237✔
597
        case 0:
1✔
598
                return ErrNoPrimaryKey
1✔
599
        case 1:
234✔
600
                return nil
234✔
601
        }
602
        return fmt.Errorf("\"%s\": %w", stmt.table, ErrMultiplePrimaryKeys)
2✔
603
}
604

605
func (stmt *CreateTableStmt) primaryKeyCols() []string {
225✔
606
        if len(stmt.pkColNames) > 0 {
448✔
607
                return stmt.pkColNames
223✔
608
        }
223✔
609

610
        for _, spec := range stmt.colsSpec {
4✔
611
                if spec.primaryKey {
4✔
612
                        return []string{spec.colName}
2✔
613
                }
2✔
614
        }
615
        return nil
×
616
}
617

618
func persistColumn(tx *SQLTx, col *Column) error {
725✔
619
        //{auto_incremental | nullable}{maxLen}{colNAME})
725✔
620
        v := make([]byte, 1+4+len(col.colName))
725✔
621

725✔
622
        if col.autoIncrement {
801✔
623
                v[0] = v[0] | autoIncrementFlag
76✔
624
        }
76✔
625

626
        if col.notNull {
772✔
627
                v[0] = v[0] | nullableFlag
47✔
628
        }
47✔
629

630
        binary.BigEndian.PutUint32(v[1:], uint32(col.MaxLen()))
725✔
631

725✔
632
        copy(v[5:], []byte(col.Name()))
725✔
633

725✔
634
        mappedKey := MapKey(
725✔
635
                tx.sqlPrefix(),
725✔
636
                catalogColumnPrefix,
725✔
637
                EncodeID(DatabaseID),
725✔
638
                EncodeID(col.table.id),
725✔
639
                EncodeID(col.id),
725✔
640
                []byte(col.colType),
725✔
641
        )
725✔
642

725✔
643
        return tx.set(mappedKey, nil, v)
725✔
644
}
645

646
func persistCheck(tx *SQLTx, table *Table, check *CheckConstraint) error {
7✔
647
        mappedKey := MapKey(
7✔
648
                tx.sqlPrefix(),
7✔
649
                catalogCheckPrefix,
7✔
650
                EncodeID(DatabaseID),
7✔
651
                EncodeID(table.id),
7✔
652
                EncodeID(check.id),
7✔
653
        )
7✔
654

7✔
655
        name := check.name
7✔
656
        expText := check.exp.String()
7✔
657

7✔
658
        val := make([]byte, 2+len(name)+len(expText))
7✔
659

7✔
660
        if len(name) > 256 {
7✔
661
                return fmt.Errorf("constraint name len: %w", ErrMaxLengthExceeded)
×
662
        }
×
663

664
        val[0] = byte(len(name)) - 1
7✔
665

7✔
666
        copy(val[1:], []byte(name))
7✔
667
        copy(val[1+len(name):], []byte(expText))
7✔
668

7✔
669
        return tx.set(mappedKey, nil, val)
7✔
670
}
671

672
type ColSpec struct {
673
        colName       string
674
        colType       SQLValueType
675
        maxLen        int
676
        autoIncrement bool
677
        notNull       bool
678
        primaryKey    bool
679
}
680

681
func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {
188✔
682
        return &ColSpec{
188✔
683
                colName:       name,
188✔
684
                colType:       colType,
188✔
685
                maxLen:        maxLen,
188✔
686
                autoIncrement: autoIncrement,
188✔
687
                notNull:       notNull,
188✔
688
        }
188✔
689
}
188✔
690

691
type CreateIndexStmt struct {
692
        unique      bool
693
        ifNotExists bool
694
        table       string
695
        cols        []string
696
}
697

698
func NewCreateIndexStmt(table string, cols []string, isUnique bool) *CreateIndexStmt {
72✔
699
        return &CreateIndexStmt{unique: isUnique, table: table, cols: cols}
72✔
700
}
72✔
701

702
func (stmt *CreateIndexStmt) readOnly() bool {
7✔
703
        return false
7✔
704
}
7✔
705

706
func (stmt *CreateIndexStmt) requiredPrivileges() []SQLPrivilege {
7✔
707
        return []SQLPrivilege{SQLPrivilegeCreate}
7✔
708
}
7✔
709

710
func (stmt *CreateIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
711
        return nil
1✔
712
}
1✔
713

714
// execAt creates an index on stmt.cols of stmt.table and persists its catalog
// entry. CreateTableStmt.execAt also calls it to create the primary index.
//
// Validation, in order:
//   - 1..MaxNumberOfColumnsInIndex columns (ErrIllegalArguments /
//     ErrMaxNumberOfColumnsInIndexExceeded);
//   - every column exists and is not JSON (ErrCannotIndexJson);
//   - unless lazyIndexConstraintValidation is enabled, each variable-sized
//     column needs 0 < MaxLen <= MaxKeyLen, and the sum of MaxLen over all
//     columns must not exceed MaxKeyLen (ErrLimitedKeyType);
//   - a UNIQUE index on a table that already has a primary index is only
//     allowed while the table has no rows (ErrLimitedIndexCreation).
//
// With IF NOT EXISTS, an ErrIndexAlreadyExists from newIndex is swallowed and
// tx is returned.
func (stmt *CreateIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
	if len(stmt.cols) < 1 {
		return nil, ErrIllegalArguments
	}

	if len(stmt.cols) > MaxNumberOfColumnsInIndex {
		return nil, ErrMaxNumberOfColumnsInIndexExceeded
	}

	table, err := tx.catalog.GetTableByName(stmt.table)
	if err != nil {
		return nil, err
	}

	colIDs := make([]uint32, len(stmt.cols))

	// sum of the columns' MaxLen, checked against MaxKeyLen after the loop
	indexKeyLen := 0

	for i, colName := range stmt.cols {
		col, err := table.GetColumnByName(colName)
		if err != nil {
			return nil, err
		}

		if col.Type() == JSONType {
			return nil, ErrCannotIndexJson
		}

		if variableSizedType(col.colType) && !tx.engine.lazyIndexConstraintValidation && (col.MaxLen() == 0 || col.MaxLen() > MaxKeyLen) {
			return nil, fmt.Errorf("%w: can not create index using column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
		}

		indexKeyLen += col.MaxLen()

		colIDs[i] = col.id
	}

	if !tx.engine.lazyIndexConstraintValidation && indexKeyLen > MaxKeyLen {
		return nil, fmt.Errorf("%w: can not create index using columns '%v'. Max key length is %d", ErrLimitedKeyType, stmt.cols, MaxKeyLen)
	}

	if stmt.unique && table.primaryIndex != nil {
		// check table is empty: probe for any key under the primary index's
		// mapped prefix; ErrKeyNotFound means there are no rows
		pkPrefix := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id))
		_, _, err := tx.getWithPrefix(ctx, pkPrefix, nil)
		if errors.Is(err, store.ErrIndexNotFound) {
			return nil, ErrTableDoesNotExist
		}
		if err == nil {
			return nil, ErrLimitedIndexCreation
		} else if !errors.Is(err, store.ErrKeyNotFound) {
			return nil, err
		}
	}

	index, err := table.newIndex(stmt.unique, colIDs)
	if errors.Is(err, ErrIndexAlreadyExists) && stmt.ifNotExists {
		return tx, nil
	}
	if err != nil {
		return nil, err
	}

	// v={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)}
	// TODO: currently only ASC order is supported
	// each column takes EncIDLen bytes of ID plus one order byte, left as 0
	colSpecLen := EncIDLen + 1

	encodedValues := make([]byte, 1+len(index.cols)*colSpecLen)

	if index.IsUnique() {
		encodedValues[0] = 1
	}

	for i, col := range index.cols {
		copy(encodedValues[1+i*colSpecLen:], EncodeID(col.id))
	}

	mappedKey := MapKey(tx.sqlPrefix(), catalogIndexPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(index.id))

	err = tx.set(mappedKey, nil, encodedValues)
	if err != nil {
		return nil, err
	}

	tx.mutatedCatalog = true

	return tx, nil
}
802

803
type AddColumnStmt struct {
804
        table   string
805
        colSpec *ColSpec
806
}
807

808
func NewAddColumnStmt(table string, colSpec *ColSpec) *AddColumnStmt {
6✔
809
        return &AddColumnStmt{table: table, colSpec: colSpec}
6✔
810
}
6✔
811

812
func (stmt *AddColumnStmt) readOnly() bool {
4✔
813
        return false
4✔
814
}
4✔
815

816
func (stmt *AddColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
817
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
818
}
4✔
819

820
func (stmt *AddColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
821
        return nil
1✔
822
}
1✔
823

824
func (stmt *AddColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
825
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
826
        if err != nil {
20✔
827
                return nil, err
1✔
828
        }
1✔
829

830
        col, err := table.newColumn(stmt.colSpec)
18✔
831
        if err != nil {
24✔
832
                return nil, err
6✔
833
        }
6✔
834

835
        err = persistColumn(tx, col)
12✔
836
        if err != nil {
12✔
837
                return nil, err
×
838
        }
×
839

840
        tx.mutatedCatalog = true
12✔
841

12✔
842
        return tx, nil
12✔
843
}
844

845
type RenameTableStmt struct {
846
        oldName string
847
        newName string
848
}
849

850
func (stmt *RenameTableStmt) readOnly() bool {
1✔
851
        return false
1✔
852
}
1✔
853

854
func (stmt *RenameTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
855
        return []SQLPrivilege{SQLPrivilegeAlter}
1✔
856
}
1✔
857

858
func (stmt *RenameTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
859
        return nil
1✔
860
}
1✔
861

862
func (stmt *RenameTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
863
        table, err := tx.catalog.renameTable(stmt.oldName, stmt.newName)
6✔
864
        if err != nil {
10✔
865
                return nil, err
4✔
866
        }
4✔
867

868
        // update table name
869
        mappedKey := MapKey(
2✔
870
                tx.sqlPrefix(),
2✔
871
                catalogTablePrefix,
2✔
872
                EncodeID(DatabaseID),
2✔
873
                EncodeID(table.id),
2✔
874
        )
2✔
875
        err = tx.set(mappedKey, nil, []byte(stmt.newName))
2✔
876
        if err != nil {
2✔
877
                return nil, err
×
878
        }
×
879

880
        tx.mutatedCatalog = true
2✔
881

2✔
882
        return tx, nil
2✔
883
}
884

885
type RenameColumnStmt struct {
886
        table   string
887
        oldName string
888
        newName string
889
}
890

891
func NewRenameColumnStmt(table, oldName, newName string) *RenameColumnStmt {
3✔
892
        return &RenameColumnStmt{table: table, oldName: oldName, newName: newName}
3✔
893
}
3✔
894

895
func (stmt *RenameColumnStmt) readOnly() bool {
4✔
896
        return false
4✔
897
}
4✔
898

899
func (stmt *RenameColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
900
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
901
}
4✔
902

903
func (stmt *RenameColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
904
        return nil
1✔
905
}
1✔
906

907
func (stmt *RenameColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
10✔
908
        table, err := tx.catalog.GetTableByName(stmt.table)
10✔
909
        if err != nil {
11✔
910
                return nil, err
1✔
911
        }
1✔
912

913
        col, err := table.renameColumn(stmt.oldName, stmt.newName)
9✔
914
        if err != nil {
12✔
915
                return nil, err
3✔
916
        }
3✔
917

918
        err = persistColumn(tx, col)
6✔
919
        if err != nil {
6✔
920
                return nil, err
×
921
        }
×
922

923
        tx.mutatedCatalog = true
6✔
924

6✔
925
        return tx, nil
6✔
926
}
927

928
type DropColumnStmt struct {
929
        table   string
930
        colName string
931
}
932

933
func NewDropColumnStmt(table, colName string) *DropColumnStmt {
8✔
934
        return &DropColumnStmt{table: table, colName: colName}
8✔
935
}
8✔
936

937
func (stmt *DropColumnStmt) readOnly() bool {
2✔
938
        return false
2✔
939
}
2✔
940

941
func (stmt *DropColumnStmt) requiredPrivileges() []SQLPrivilege {
2✔
942
        return []SQLPrivilege{SQLPrivilegeDrop}
2✔
943
}
2✔
944

945
func (stmt *DropColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
946
        return nil
1✔
947
}
1✔
948

949
func (stmt *DropColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
950
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
951
        if err != nil {
21✔
952
                return nil, err
2✔
953
        }
2✔
954

955
        col, err := table.GetColumnByName(stmt.colName)
17✔
956
        if err != nil {
21✔
957
                return nil, err
4✔
958
        }
4✔
959

960
        err = canDropColumn(tx, table, col)
13✔
961
        if err != nil {
14✔
962
                return nil, err
1✔
963
        }
1✔
964

965
        err = table.deleteColumn(col)
12✔
966
        if err != nil {
16✔
967
                return nil, err
4✔
968
        }
4✔
969

970
        err = persistColumnDeletion(ctx, tx, col)
8✔
971
        if err != nil {
8✔
972
                return nil, err
×
973
        }
×
974

975
        tx.mutatedCatalog = true
8✔
976

8✔
977
        return tx, nil
8✔
978
}
979

980
func canDropColumn(tx *SQLTx, table *Table, col *Column) error {
13✔
981
        colSpecs := make([]*ColSpec, 0, len(table.Cols())-1)
13✔
982
        for _, c := range table.cols {
86✔
983
                if c.id != col.id {
133✔
984
                        colSpecs = append(colSpecs, &ColSpec{colName: c.Name(), colType: c.Type()})
60✔
985
                }
60✔
986
        }
987

988
        row := zeroRow(table.Name(), colSpecs)
13✔
989
        for name, check := range table.checkConstraints {
17✔
990
                _, err := check.exp.reduce(tx, row, table.name)
4✔
991
                if errors.Is(err, ErrColumnDoesNotExist) {
5✔
992
                        return fmt.Errorf("%w %s because %s constraint requires it", ErrCannotDropColumn, col.Name(), name)
1✔
993
                }
1✔
994

995
                if err != nil {
3✔
996
                        return err
×
997
                }
×
998
        }
999
        return nil
12✔
1000
}
1001

1002
func persistColumnDeletion(ctx context.Context, tx *SQLTx, col *Column) error {
9✔
1003
        mappedKey := MapKey(
9✔
1004
                tx.sqlPrefix(),
9✔
1005
                catalogColumnPrefix,
9✔
1006
                EncodeID(DatabaseID),
9✔
1007
                EncodeID(col.table.id),
9✔
1008
                EncodeID(col.id),
9✔
1009
                []byte(col.colType),
9✔
1010
        )
9✔
1011

9✔
1012
        return tx.delete(ctx, mappedKey)
9✔
1013
}
9✔
1014

1015
type DropConstraintStmt struct {
1016
        table          string
1017
        constraintName string
1018
}
1019

1020
func (stmt *DropConstraintStmt) readOnly() bool {
×
1021
        return false
×
1022
}
×
1023

1024
func (stmt *DropConstraintStmt) requiredPrivileges() []SQLPrivilege {
×
1025
        return []SQLPrivilege{SQLPrivilegeDrop}
×
1026
}
×
1027

1028
func (stmt *DropConstraintStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
1029
        table, err := tx.catalog.GetTableByName(stmt.table)
4✔
1030
        if err != nil {
4✔
1031
                return nil, err
×
1032
        }
×
1033

1034
        id, err := table.deleteCheck(stmt.constraintName)
4✔
1035
        if err != nil {
5✔
1036
                return nil, err
1✔
1037
        }
1✔
1038

1039
        err = persistCheckDeletion(ctx, tx, table.id, id)
3✔
1040

3✔
1041
        tx.mutatedCatalog = true
3✔
1042

3✔
1043
        return tx, err
3✔
1044
}
1045

1046
func persistCheckDeletion(ctx context.Context, tx *SQLTx, tableID uint32, checkId uint32) error {
3✔
1047
        mappedKey := MapKey(
3✔
1048
                tx.sqlPrefix(),
3✔
1049
                catalogCheckPrefix,
3✔
1050
                EncodeID(DatabaseID),
3✔
1051
                EncodeID(tableID),
3✔
1052
                EncodeID(checkId),
3✔
1053
        )
3✔
1054
        return tx.delete(ctx, mappedKey)
3✔
1055
}
3✔
1056

1057
func (stmt *DropConstraintStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
1058
        return nil
×
1059
}
×
1060

1061
type UpsertIntoStmt struct {
1062
        isInsert   bool
1063
        tableRef   *tableRef
1064
        cols       []string
1065
        ds         DataSource
1066
        onConflict *OnConflictDo
1067
}
1068

1069
func (stmt *UpsertIntoStmt) readOnly() bool {
101✔
1070
        return false
101✔
1071
}
101✔
1072

1073
func (stmt *UpsertIntoStmt) requiredPrivileges() []SQLPrivilege {
101✔
1074
        privileges := stmt.privileges()
101✔
1075
        if stmt.ds != nil {
200✔
1076
                privileges = append(privileges, stmt.ds.requiredPrivileges()...)
99✔
1077
        }
99✔
1078
        return privileges
101✔
1079
}
1080

1081
func (stmt *UpsertIntoStmt) privileges() []SQLPrivilege {
101✔
1082
        if stmt.isInsert {
190✔
1083
                return []SQLPrivilege{SQLPrivilegeInsert}
89✔
1084
        }
89✔
1085
        return []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate}
12✔
1086
}
1087

1088
func NewUpsertIntoStmt(table string, cols []string, ds DataSource, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt {
120✔
1089
        return &UpsertIntoStmt{
120✔
1090
                isInsert:   isInsert,
120✔
1091
                tableRef:   NewTableRef(table, ""),
120✔
1092
                cols:       cols,
120✔
1093
                ds:         ds,
120✔
1094
                onConflict: onConflict,
120✔
1095
        }
120✔
1096
}
120✔
1097

1098
type RowSpec struct {
1099
        Values []ValueExp
1100
}
1101

1102
func NewRowSpec(values []ValueExp) *RowSpec {
129✔
1103
        return &RowSpec{
129✔
1104
                Values: values,
129✔
1105
        }
129✔
1106
}
129✔
1107

1108
type OnConflictDo struct{}
1109

1110
func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
11✔
1111
        ds, ok := stmt.ds.(*valuesDataSource)
11✔
1112
        if !ok {
11✔
1113
                return stmt.ds.inferParameters(ctx, tx, params)
×
1114
        }
×
1115

1116
        emptyDescriptors := make(map[string]ColDescriptor)
11✔
1117
        for _, row := range ds.rows {
23✔
1118
                if len(stmt.cols) != len(row.Values) {
13✔
1119
                        return ErrInvalidNumberOfValues
1✔
1120
                }
1✔
1121

1122
                for i, val := range row.Values {
36✔
1123
                        table, err := stmt.tableRef.referencedTable(tx)
25✔
1124
                        if err != nil {
26✔
1125
                                return err
1✔
1126
                        }
1✔
1127

1128
                        col, err := table.GetColumnByName(stmt.cols[i])
24✔
1129
                        if err != nil {
25✔
1130
                                return err
1✔
1131
                        }
1✔
1132

1133
                        err = val.requiresType(col.colType, emptyDescriptors, params, table.name)
23✔
1134
                        if err != nil {
25✔
1135
                                return err
2✔
1136
                        }
2✔
1137
                }
1138
        }
1139
        return nil
6✔
1140
}
1141

1142
func (stmt *UpsertIntoStmt) validate(table *Table) (map[uint32]int, error) {
2,354✔
1143
        selPosByColID := make(map[uint32]int, len(stmt.cols))
2,354✔
1144

2,354✔
1145
        for i, c := range stmt.cols {
11,387✔
1146
                col, err := table.GetColumnByName(c)
9,033✔
1147
                if err != nil {
9,035✔
1148
                        return nil, err
2✔
1149
                }
2✔
1150

1151
                _, duplicated := selPosByColID[col.id]
9,031✔
1152
                if duplicated {
9,032✔
1153
                        return nil, fmt.Errorf("%w (%s)", ErrDuplicatedColumn, col.colName)
1✔
1154
                }
1✔
1155

1156
                selPosByColID[col.id] = i
9,030✔
1157
        }
1158

1159
        return selPosByColID, nil
2,351✔
1160
}
1161

1162
func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2,357✔
1163
        table, err := stmt.tableRef.referencedTable(tx)
2,357✔
1164
        if err != nil {
2,360✔
1165
                return nil, err
3✔
1166
        }
3✔
1167

1168
        selPosByColID, err := stmt.validate(table)
2,354✔
1169
        if err != nil {
2,357✔
1170
                return nil, err
3✔
1171
        }
3✔
1172

1173
        r := &Row{
2,351✔
1174
                ValuesByPosition: make([]TypedValue, len(table.cols)),
2,351✔
1175
                ValuesBySelector: make(map[string]TypedValue),
2,351✔
1176
        }
2,351✔
1177

2,351✔
1178
        reader, err := stmt.ds.Resolve(ctx, tx, params, nil)
2,351✔
1179
        if err != nil {
2,351✔
1180
                return nil, err
×
1181
        }
×
1182
        defer reader.Close()
2,351✔
1183

2,351✔
1184
        for {
7,207✔
1185
                row, err := reader.Read(ctx)
4,856✔
1186
                if errors.Is(err, ErrNoMoreRows) {
7,164✔
1187
                        break
2,308✔
1188
                }
1189
                if err != nil {
2,558✔
1190
                        return nil, err
10✔
1191
                }
10✔
1192

1193
                if len(row.ValuesByPosition) != len(stmt.cols) {
2,540✔
1194
                        return nil, ErrInvalidNumberOfValues
2✔
1195
                }
2✔
1196

1197
                valuesByColID := make(map[uint32]TypedValue)
2,536✔
1198

2,536✔
1199
                var pkMustExist bool
2,536✔
1200

2,536✔
1201
                for colID, col := range table.colsByID {
13,397✔
1202
                        colPos, specified := selPosByColID[colID]
10,861✔
1203
                        if !specified {
12,109✔
1204
                                // TODO: Default values
1,248✔
1205
                                if col.notNull && !col.autoIncrement {
1,249✔
1206
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
1✔
1207
                                }
1✔
1208

1209
                                // inject auto-incremental pk value
1210
                                if stmt.isInsert && col.autoIncrement {
2,305✔
1211
                                        // current implementation assumes only PK can be set as autoincremental
1,058✔
1212
                                        table.maxPK++
1,058✔
1213

1,058✔
1214
                                        pkCol := table.primaryIndex.cols[0]
1,058✔
1215
                                        valuesByColID[pkCol.id] = &Integer{val: table.maxPK}
1,058✔
1216

1,058✔
1217
                                        if _, ok := tx.firstInsertedPKs[table.name]; !ok {
1,924✔
1218
                                                tx.firstInsertedPKs[table.name] = table.maxPK
866✔
1219
                                        }
866✔
1220
                                        tx.lastInsertedPKs[table.name] = table.maxPK
1,058✔
1221
                                }
1222

1223
                                continue
1,247✔
1224
                        }
1225

1226
                        // value was specified
1227
                        cVal := row.ValuesByPosition[colPos]
9,613✔
1228

9,613✔
1229
                        val, err := cVal.substitute(params)
9,613✔
1230
                        if err != nil {
9,613✔
1231
                                return nil, err
×
1232
                        }
×
1233

1234
                        rval, err := val.reduce(tx, nil, table.name)
9,613✔
1235
                        if err != nil {
9,613✔
1236
                                return nil, err
×
1237
                        }
×
1238

1239
                        if rval.IsNull() {
9,719✔
1240
                                if col.notNull || col.autoIncrement {
106✔
1241
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
×
1242
                                }
×
1243

1244
                                continue
106✔
1245
                        }
1246

1247
                        if col.autoIncrement {
9,526✔
1248
                                // validate specified value
19✔
1249
                                nl, isNumber := rval.RawValue().(int64)
19✔
1250
                                if !isNumber {
19✔
1251
                                        return nil, fmt.Errorf("%w (expecting numeric value)", ErrInvalidValue)
×
1252
                                }
×
1253

1254
                                pkMustExist = nl <= table.maxPK
19✔
1255

19✔
1256
                                if _, ok := tx.firstInsertedPKs[table.name]; !ok {
38✔
1257
                                        tx.firstInsertedPKs[table.name] = nl
19✔
1258
                                }
19✔
1259
                                tx.lastInsertedPKs[table.name] = nl
19✔
1260
                        }
1261

1262
                        valuesByColID[colID] = rval
9,507✔
1263
                }
1264

1265
                for i, col := range table.cols {
13,392✔
1266
                        v := valuesByColID[col.id]
10,857✔
1267

10,857✔
1268
                        if v == nil {
11,151✔
1269
                                v = NewNull(AnyType)
294✔
1270
                        } else if len(table.checkConstraints) > 0 && col.Type() == JSONType {
10,862✔
1271
                                s, _ := v.RawValue().(string)
5✔
1272
                                jsonVal, err := NewJsonFromString(s)
5✔
1273
                                if err != nil {
5✔
1274
                                        return nil, err
×
1275
                                }
×
1276
                                v = jsonVal
5✔
1277
                        }
1278

1279
                        r.ValuesByPosition[i] = v
10,857✔
1280
                        r.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
10,857✔
1281
                }
1282

1283
                if err := checkConstraints(tx, table.checkConstraints, r, table.name); err != nil {
2,541✔
1284
                        return nil, err
6✔
1285
                }
6✔
1286

1287
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
2,529✔
1288
                if err != nil {
2,534✔
1289
                        return nil, err
5✔
1290
                }
5✔
1291

1292
                // pk entry
1293
                mappedPKey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
2,524✔
1294
                if len(mappedPKey) > MaxKeyLen {
2,524✔
1295
                        return nil, ErrMaxKeyLengthExceeded
×
1296
                }
×
1297

1298
                _, err = tx.get(ctx, mappedPKey)
2,524✔
1299
                if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
2,524✔
1300
                        return nil, err
×
1301
                }
×
1302

1303
                if errors.Is(err, store.ErrKeyNotFound) && pkMustExist {
2,526✔
1304
                        return nil, fmt.Errorf("%w: specified value must be greater than current one", ErrInvalidValue)
2✔
1305
                }
2✔
1306

1307
                if stmt.isInsert {
4,863✔
1308
                        if err == nil && stmt.onConflict == nil {
2,345✔
1309
                                return nil, store.ErrKeyAlreadyExists
4✔
1310
                        }
4✔
1311

1312
                        if err == nil && stmt.onConflict != nil {
2,340✔
1313
                                // TODO: conflict resolution may be extended. Currently only supports "ON CONFLICT DO NOTHING"
3✔
1314
                                continue
3✔
1315
                        }
1316
                }
1317

1318
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, !stmt.isInsert)
2,515✔
1319
                if err != nil {
2,528✔
1320
                        return nil, err
13✔
1321
                }
13✔
1322
        }
1323
        return tx, nil
2,308✔
1324
}
1325

1326
func checkConstraints(tx *SQLTx, checks map[string]CheckConstraint, row *Row, table string) error {
2,571✔
1327
        for _, check := range checks {
2,619✔
1328
                val, err := check.exp.reduce(tx, row, table)
48✔
1329
                if err != nil {
49✔
1330
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, err)
1✔
1331
                }
1✔
1332

1333
                if val.Type() != BooleanType {
47✔
1334
                        return ErrInvalidCheckConstraint
×
1335
                }
×
1336

1337
                if !val.RawValue().(bool) {
54✔
1338
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, check.exp.String())
7✔
1339
                }
7✔
1340
        }
1341
        return nil
2,563✔
1342
}
1343

1344
func (tx *SQLTx) encodeRowValue(valuesByColID map[uint32]TypedValue, table *Table) ([]byte, error) {
2,708✔
1345
        valbuf := bytes.Buffer{}
2,708✔
1346

2,708✔
1347
        // null values are not serialized
2,708✔
1348
        encodedVals := 0
2,708✔
1349
        for _, v := range valuesByColID {
13,703✔
1350
                if !v.IsNull() {
21,972✔
1351
                        encodedVals++
10,977✔
1352
                }
10,977✔
1353
        }
1354

1355
        b := make([]byte, EncLenLen)
2,708✔
1356
        binary.BigEndian.PutUint32(b, uint32(encodedVals))
2,708✔
1357

2,708✔
1358
        _, err := valbuf.Write(b)
2,708✔
1359
        if err != nil {
2,708✔
1360
                return nil, err
×
1361
        }
×
1362

1363
        for _, col := range table.cols {
13,973✔
1364
                rval, specified := valuesByColID[col.id]
11,265✔
1365
                if !specified || rval.IsNull() {
11,559✔
1366
                        continue
294✔
1367
                }
1368

1369
                b := make([]byte, EncIDLen)
10,971✔
1370
                binary.BigEndian.PutUint32(b, uint32(col.id))
10,971✔
1371

10,971✔
1372
                _, err = valbuf.Write(b)
10,971✔
1373
                if err != nil {
10,971✔
1374
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1375
                }
×
1376

1377
                encVal, err := EncodeValue(rval, col.colType, col.MaxLen())
10,971✔
1378
                if err != nil {
10,979✔
1379
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
8✔
1380
                }
8✔
1381

1382
                _, err = valbuf.Write(encVal)
10,963✔
1383
                if err != nil {
10,963✔
1384
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1385
                }
×
1386
        }
1387

1388
        return valbuf.Bytes(), nil
2,700✔
1389
}
1390

1391
func (tx *SQLTx) doUpsert(ctx context.Context, pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table, reuseIndex bool) error {
2,549✔
1392
        var reusableIndexEntries map[uint32]struct{}
2,549✔
1393

2,549✔
1394
        if reuseIndex && len(table.indexes) > 1 {
2,606✔
1395
                currPKRow, err := tx.fetchPKRow(ctx, table, valuesByColID)
57✔
1396
                if err == nil {
93✔
1397
                        currValuesByColID := make(map[uint32]TypedValue, len(currPKRow.ValuesBySelector))
36✔
1398

36✔
1399
                        for _, col := range table.cols {
161✔
1400
                                encSel := EncodeSelector("", table.name, col.colName)
125✔
1401
                                currValuesByColID[col.id] = currPKRow.ValuesBySelector[encSel]
125✔
1402
                        }
125✔
1403

1404
                        reusableIndexEntries, err = tx.deprecateIndexEntries(pkEncVals, currValuesByColID, valuesByColID, table)
36✔
1405
                        if err != nil {
36✔
1406
                                return err
×
1407
                        }
×
1408
                } else if !errors.Is(err, ErrNoMoreRows) {
21✔
1409
                        return err
×
1410
                }
×
1411
        }
1412

1413
        rowKey := MapKey(tx.sqlPrefix(), RowPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(PKIndexID), pkEncVals)
2,549✔
1414

2,549✔
1415
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
2,549✔
1416
        if err != nil {
2,557✔
1417
                return err
8✔
1418
        }
8✔
1419

1420
        err = tx.set(rowKey, nil, encodedRowValue)
2,541✔
1421
        if err != nil {
2,541✔
1422
                return err
×
1423
        }
×
1424

1425
        // create in-memory and validate entries for secondary indexes
1426
        for _, index := range table.indexes {
5,983✔
1427
                if index.IsPrimary() {
5,983✔
1428
                        continue
2,541✔
1429
                }
1430

1431
                if reusableIndexEntries != nil {
978✔
1432
                        _, reusable := reusableIndexEntries[index.id]
77✔
1433
                        if reusable {
127✔
1434
                                continue
50✔
1435
                        }
1436
                }
1437

1438
                encodedValues := make([][]byte, 2+len(index.cols))
851✔
1439
                encodedValues[0] = EncodeID(table.id)
851✔
1440
                encodedValues[1] = EncodeID(index.id)
851✔
1441

851✔
1442
                indexKeyLen := 0
851✔
1443

851✔
1444
                for i, col := range index.cols {
1,783✔
1445
                        rval, specified := valuesByColID[col.id]
932✔
1446
                        if !specified {
1,017✔
1447
                                rval = &NullValue{t: col.colType}
85✔
1448
                        }
85✔
1449

1450
                        encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
932✔
1451
                        if err != nil {
932✔
1452
                                return fmt.Errorf("%w: index on '%s' and column '%s'", err, index.Name(), col.colName)
×
1453
                        }
×
1454

1455
                        if n > MaxKeyLen {
932✔
1456
                                return fmt.Errorf("%w: can not index entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1457
                        }
×
1458

1459
                        indexKeyLen += n
932✔
1460

932✔
1461
                        encodedValues[i+2] = encVal
932✔
1462
                }
1463

1464
                if indexKeyLen > MaxKeyLen {
851✔
1465
                        return fmt.Errorf("%w: can not index entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1466
                }
×
1467

1468
                smkey := MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...)
851✔
1469

851✔
1470
                // no other equivalent entry should be already indexed
851✔
1471
                if index.IsUnique() {
928✔
1472
                        _, valRef, err := tx.getWithPrefix(ctx, smkey, nil)
77✔
1473
                        if err == nil && (valRef.KVMetadata() == nil || !valRef.KVMetadata().Deleted()) {
82✔
1474
                                return store.ErrKeyAlreadyExists
5✔
1475
                        } else if !errors.Is(err, store.ErrKeyNotFound) {
77✔
1476
                                return err
×
1477
                        }
×
1478
                }
1479

1480
                err = tx.setTransient(smkey, nil, encodedRowValue) // only-indexable
846✔
1481
                if err != nil {
846✔
1482
                        return err
×
1483
                }
×
1484
        }
1485

1486
        tx.updatedRows++
2,536✔
1487

2,536✔
1488
        return nil
2,536✔
1489
}
1490

1491
func encodedKey(index *Index, valuesByColID map[uint32]TypedValue) ([]byte, error) {
16,107✔
1492
        valbuf := bytes.Buffer{}
16,107✔
1493

16,107✔
1494
        indexKeyLen := 0
16,107✔
1495

16,107✔
1496
        for _, col := range index.cols {
32,226✔
1497
                rval, specified := valuesByColID[col.id]
16,119✔
1498
                if !specified || rval.IsNull() {
16,122✔
1499
                        return nil, ErrPKCanNotBeNull
3✔
1500
                }
3✔
1501

1502
                encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
16,116✔
1503
                if err != nil {
16,118✔
1504
                        return nil, fmt.Errorf("%w: index of table '%s' and column '%s'", err, index.table.name, col.colName)
2✔
1505
                }
2✔
1506

1507
                if n > MaxKeyLen {
16,114✔
1508
                        return nil, fmt.Errorf("%w: invalid key entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1509
                }
×
1510

1511
                indexKeyLen += n
16,114✔
1512

16,114✔
1513
                _, err = valbuf.Write(encVal)
16,114✔
1514
                if err != nil {
16,114✔
1515
                        return nil, err
×
1516
                }
×
1517
        }
1518

1519
        if indexKeyLen > MaxKeyLen {
16,102✔
1520
                return nil, fmt.Errorf("%w: invalid key entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1521
        }
×
1522

1523
        return valbuf.Bytes(), nil
16,102✔
1524
}
1525

1526
func (tx *SQLTx) fetchPKRow(ctx context.Context, table *Table, valuesByColID map[uint32]TypedValue) (*Row, error) {
57✔
1527
        pkRanges := make(map[uint32]*typedValueRange, len(table.primaryIndex.cols))
57✔
1528

57✔
1529
        for _, pkCol := range table.primaryIndex.cols {
114✔
1530
                pkVal := valuesByColID[pkCol.id]
57✔
1531

57✔
1532
                pkRanges[pkCol.id] = &typedValueRange{
57✔
1533
                        lRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1534
                        hRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1535
                }
57✔
1536
        }
57✔
1537

1538
        scanSpecs := &ScanSpecs{
57✔
1539
                Index:         table.primaryIndex,
57✔
1540
                rangesByColID: pkRanges,
57✔
1541
        }
57✔
1542

57✔
1543
        r, err := newRawRowReader(tx, nil, table, period{}, table.name, scanSpecs)
57✔
1544
        if err != nil {
57✔
1545
                return nil, err
×
1546
        }
×
1547

1548
        defer func() {
114✔
1549
                r.Close()
57✔
1550
        }()
57✔
1551

1552
        return r.Read(ctx)
57✔
1553
}
1554

1555
// deprecateIndexEntries mark previous index entries as deleted
1556
func (tx *SQLTx) deprecateIndexEntries(
1557
        pkEncVals []byte,
1558
        currValuesByColID, newValuesByColID map[uint32]TypedValue,
1559
        table *Table) (reusableIndexEntries map[uint32]struct{}, err error) {
36✔
1560

36✔
1561
        encodedRowValue, err := tx.encodeRowValue(currValuesByColID, table)
36✔
1562
        if err != nil {
36✔
1563
                return nil, err
×
1564
        }
×
1565

1566
        reusableIndexEntries = make(map[uint32]struct{})
36✔
1567

36✔
1568
        for _, index := range table.indexes {
149✔
1569
                if index.IsPrimary() {
149✔
1570
                        continue
36✔
1571
                }
1572

1573
                encodedValues := make([][]byte, 2+len(index.cols)+1)
77✔
1574
                encodedValues[0] = EncodeID(table.id)
77✔
1575
                encodedValues[1] = EncodeID(index.id)
77✔
1576
                encodedValues[len(encodedValues)-1] = pkEncVals
77✔
1577

77✔
1578
                // existent index entry is deleted only if it differs from existent one
77✔
1579
                sameIndexKey := true
77✔
1580

77✔
1581
                for i, col := range index.cols {
159✔
1582
                        currVal, specified := currValuesByColID[col.id]
82✔
1583
                        if !specified {
82✔
1584
                                currVal = &NullValue{t: col.colType}
×
1585
                        }
×
1586

1587
                        newVal, specified := newValuesByColID[col.id]
82✔
1588
                        if !specified {
86✔
1589
                                newVal = &NullValue{t: col.colType}
4✔
1590
                        }
4✔
1591

1592
                        r, err := currVal.Compare(newVal)
82✔
1593
                        if err != nil {
82✔
1594
                                return nil, err
×
1595
                        }
×
1596

1597
                        sameIndexKey = sameIndexKey && r == 0
82✔
1598

82✔
1599
                        encVal, _, _ := EncodeValueAsKey(currVal, col.colType, col.MaxLen())
82✔
1600

82✔
1601
                        encodedValues[i+3] = encVal
82✔
1602
                }
1603

1604
                // mark existent index entry as deleted
1605
                if sameIndexKey {
127✔
1606
                        reusableIndexEntries[index.id] = struct{}{}
50✔
1607
                } else {
77✔
1608
                        md := store.NewKVMetadata()
27✔
1609

27✔
1610
                        md.AsDeleted(true)
27✔
1611

27✔
1612
                        err = tx.set(MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...), md, encodedRowValue)
27✔
1613
                        if err != nil {
27✔
1614
                                return nil, err
×
1615
                        }
×
1616
                }
1617
        }
1618

1619
        return reusableIndexEntries, nil
36✔
1620
}
1621

1622
type UpdateStmt struct {
1623
        tableRef *tableRef
1624
        where    ValueExp
1625
        updates  []*colUpdate
1626
        indexOn  []string
1627
        limit    ValueExp
1628
        offset   ValueExp
1629
}
1630

1631
type colUpdate struct {
1632
        col string
1633
        op  CmpOperator
1634
        val ValueExp
1635
}
1636

1637
func (stmt *UpdateStmt) readOnly() bool {
4✔
1638
        return false
4✔
1639
}
4✔
1640

1641
func (stmt *UpdateStmt) requiredPrivileges() []SQLPrivilege {
4✔
1642
        return []SQLPrivilege{SQLPrivilegeUpdate}
4✔
1643
}
4✔
1644

1645
func (stmt *UpdateStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1646
        selectStmt := &SelectStmt{
1✔
1647
                ds:    stmt.tableRef,
1✔
1648
                where: stmt.where,
1✔
1649
        }
1✔
1650

1✔
1651
        err := selectStmt.inferParameters(ctx, tx, params)
1✔
1652
        if err != nil {
1✔
1653
                return err
×
1654
        }
×
1655

1656
        table, err := stmt.tableRef.referencedTable(tx)
1✔
1657
        if err != nil {
1✔
1658
                return err
×
1659
        }
×
1660

1661
        for _, update := range stmt.updates {
2✔
1662
                col, err := table.GetColumnByName(update.col)
1✔
1663
                if err != nil {
1✔
1664
                        return err
×
1665
                }
×
1666

1667
                err = update.val.requiresType(col.colType, make(map[string]ColDescriptor), params, table.name)
1✔
1668
                if err != nil {
1✔
1669
                        return err
×
1670
                }
×
1671
        }
1672

1673
        return nil
1✔
1674
}
1675

1676
func (stmt *UpdateStmt) validate(table *Table) error {
23✔
1677
        colIDs := make(map[uint32]struct{}, len(stmt.updates))
23✔
1678

23✔
1679
        for _, update := range stmt.updates {
48✔
1680
                if update.op != EQ {
25✔
1681
                        return ErrIllegalArguments
×
1682
                }
×
1683

1684
                col, err := table.GetColumnByName(update.col)
25✔
1685
                if err != nil {
26✔
1686
                        return err
1✔
1687
                }
1✔
1688

1689
                if table.PrimaryIndex().IncludesCol(col.id) {
24✔
1690
                        return ErrPKCanNotBeUpdated
×
1691
                }
×
1692

1693
                _, duplicated := colIDs[col.id]
24✔
1694
                if duplicated {
24✔
1695
                        return ErrDuplicatedColumn
×
1696
                }
×
1697

1698
                colIDs[col.id] = struct{}{}
24✔
1699
        }
1700

1701
        return nil
22✔
1702
}
1703

1704
func (stmt *UpdateStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
24✔
1705
        selectStmt := &SelectStmt{
24✔
1706
                ds:      stmt.tableRef,
24✔
1707
                where:   stmt.where,
24✔
1708
                indexOn: stmt.indexOn,
24✔
1709
                limit:   stmt.limit,
24✔
1710
                offset:  stmt.offset,
24✔
1711
        }
24✔
1712

24✔
1713
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
24✔
1714
        if err != nil {
25✔
1715
                return nil, err
1✔
1716
        }
1✔
1717
        defer rowReader.Close()
23✔
1718

23✔
1719
        table := rowReader.ScanSpecs().Index.table
23✔
1720

23✔
1721
        err = stmt.validate(table)
23✔
1722
        if err != nil {
24✔
1723
                return nil, err
1✔
1724
        }
1✔
1725

1726
        cols, err := rowReader.colsBySelector(ctx)
22✔
1727
        if err != nil {
22✔
1728
                return nil, err
×
1729
        }
×
1730

1731
        for {
78✔
1732
                row, err := rowReader.Read(ctx)
56✔
1733
                if errors.Is(err, ErrNoMoreRows) {
75✔
1734
                        break
19✔
1735
                } else if err != nil {
38✔
1736
                        return nil, err
1✔
1737
                }
1✔
1738

1739
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
36✔
1740

36✔
1741
                for _, col := range table.cols {
135✔
1742
                        encSel := EncodeSelector("", table.name, col.colName)
99✔
1743
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
99✔
1744
                }
99✔
1745

1746
                for _, update := range stmt.updates {
74✔
1747
                        col, err := table.GetColumnByName(update.col)
38✔
1748
                        if err != nil {
38✔
1749
                                return nil, err
×
1750
                        }
×
1751

1752
                        sval, err := update.val.substitute(params)
38✔
1753
                        if err != nil {
38✔
1754
                                return nil, err
×
1755
                        }
×
1756

1757
                        rval, err := sval.reduce(tx, row, table.name)
38✔
1758
                        if err != nil {
38✔
1759
                                return nil, err
×
1760
                        }
×
1761

1762
                        err = rval.requiresType(col.colType, cols, nil, table.name)
38✔
1763
                        if err != nil {
38✔
1764
                                return nil, err
×
1765
                        }
×
1766

1767
                        valuesByColID[col.id] = rval
38✔
1768
                }
1769

1770
                for i, col := range table.cols {
135✔
1771
                        v := valuesByColID[col.id]
99✔
1772

99✔
1773
                        row.ValuesByPosition[i] = v
99✔
1774
                        row.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
99✔
1775
                }
99✔
1776

1777
                if err := checkConstraints(tx, table.checkConstraints, row, table.name); err != nil {
38✔
1778
                        return nil, err
2✔
1779
                }
2✔
1780

1781
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
34✔
1782
                if err != nil {
34✔
1783
                        return nil, err
×
1784
                }
×
1785

1786
                // primary index entry
1787
                mkey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
34✔
1788

34✔
1789
                // mkey must exist
34✔
1790
                _, err = tx.get(ctx, mkey)
34✔
1791
                if err != nil {
34✔
1792
                        return nil, err
×
1793
                }
×
1794

1795
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, true)
34✔
1796
                if err != nil {
34✔
1797
                        return nil, err
×
1798
                }
×
1799
        }
1800

1801
        return tx, nil
19✔
1802
}
1803

1804
type DeleteFromStmt struct {
1805
        tableRef *tableRef
1806
        where    ValueExp
1807
        indexOn  []string
1808
        orderBy  []*OrdExp
1809
        limit    ValueExp
1810
        offset   ValueExp
1811
}
1812

1813
func NewDeleteFromStmt(table string, where ValueExp, orderBy []*OrdExp, limit ValueExp) *DeleteFromStmt {
4✔
1814
        return &DeleteFromStmt{
4✔
1815
                tableRef: NewTableRef(table, ""),
4✔
1816
                where:    where,
4✔
1817
                orderBy:  orderBy,
4✔
1818
                limit:    limit,
4✔
1819
        }
4✔
1820
}
4✔
1821

1822
func (stmt *DeleteFromStmt) readOnly() bool {
1✔
1823
        return false
1✔
1824
}
1✔
1825

1826
func (stmt *DeleteFromStmt) requiredPrivileges() []SQLPrivilege {
1✔
1827
        return []SQLPrivilege{SQLPrivilegeDelete}
1✔
1828
}
1✔
1829

1830
func (stmt *DeleteFromStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1831
        selectStmt := &SelectStmt{
1✔
1832
                ds:      stmt.tableRef,
1✔
1833
                where:   stmt.where,
1✔
1834
                orderBy: stmt.orderBy,
1✔
1835
        }
1✔
1836
        return selectStmt.inferParameters(ctx, tx, params)
1✔
1837
}
1✔
1838

1839
func (stmt *DeleteFromStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
17✔
1840
        selectStmt := &SelectStmt{
17✔
1841
                ds:      stmt.tableRef,
17✔
1842
                where:   stmt.where,
17✔
1843
                indexOn: stmt.indexOn,
17✔
1844
                orderBy: stmt.orderBy,
17✔
1845
                limit:   stmt.limit,
17✔
1846
                offset:  stmt.offset,
17✔
1847
        }
17✔
1848

17✔
1849
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
17✔
1850
        if err != nil {
19✔
1851
                return nil, err
2✔
1852
        }
2✔
1853
        defer rowReader.Close()
15✔
1854

15✔
1855
        table := rowReader.ScanSpecs().Index.table
15✔
1856

15✔
1857
        for {
153✔
1858
                row, err := rowReader.Read(ctx)
138✔
1859
                if errors.Is(err, ErrNoMoreRows) {
152✔
1860
                        break
14✔
1861
                }
1862
                if err != nil {
125✔
1863
                        return nil, err
1✔
1864
                }
1✔
1865

1866
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
123✔
1867

123✔
1868
                for _, col := range table.cols {
413✔
1869
                        encSel := EncodeSelector("", table.name, col.colName)
290✔
1870
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
290✔
1871
                }
290✔
1872

1873
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
123✔
1874
                if err != nil {
123✔
1875
                        return nil, err
×
1876
                }
×
1877

1878
                err = tx.deleteIndexEntries(pkEncVals, valuesByColID, table)
123✔
1879
                if err != nil {
123✔
1880
                        return nil, err
×
1881
                }
×
1882

1883
                tx.updatedRows++
123✔
1884
        }
1885
        return tx, nil
14✔
1886
}
1887

1888
func (tx *SQLTx) deleteIndexEntries(pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table) error {
123✔
1889
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
123✔
1890
        if err != nil {
123✔
1891
                return err
×
1892
        }
×
1893

1894
        for _, index := range table.indexes {
295✔
1895
                if !index.IsPrimary() {
221✔
1896
                        continue
49✔
1897
                }
1898

1899
                encodedValues := make([][]byte, 3+len(index.cols))
123✔
1900
                encodedValues[0] = EncodeID(DatabaseID)
123✔
1901
                encodedValues[1] = EncodeID(table.id)
123✔
1902
                encodedValues[2] = EncodeID(index.id)
123✔
1903

123✔
1904
                for i, col := range index.cols {
246✔
1905
                        val, specified := valuesByColID[col.id]
123✔
1906
                        if !specified {
123✔
1907
                                val = &NullValue{t: col.colType}
×
1908
                        }
×
1909

1910
                        encVal, _, _ := EncodeValueAsKey(val, col.colType, col.MaxLen())
123✔
1911

123✔
1912
                        encodedValues[i+3] = encVal
123✔
1913
                }
1914

1915
                md := store.NewKVMetadata()
123✔
1916

123✔
1917
                md.AsDeleted(true)
123✔
1918

123✔
1919
                err := tx.set(MapKey(tx.sqlPrefix(), RowPrefix, encodedValues...), md, encodedRowValue)
123✔
1920
                if err != nil {
123✔
1921
                        return err
×
1922
                }
×
1923
        }
1924

1925
        return nil
123✔
1926
}
1927

1928
type ValueExp interface {
1929
        inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
1930
        requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
1931
        substitute(params map[string]interface{}) (ValueExp, error)
1932
        selectors() []Selector
1933
        reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error)
1934
        reduceSelectors(row *Row, implicitTable string) ValueExp
1935
        isConstant() bool
1936
        selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error
1937
        String() string
1938
}
1939

1940
type typedValueRange struct {
1941
        lRange *typedValueSemiRange
1942
        hRange *typedValueSemiRange
1943
}
1944

1945
type typedValueSemiRange struct {
1946
        val       TypedValue
1947
        inclusive bool
1948
}
1949

1950
func (r *typedValueRange) unitary() bool {
19✔
1951
        // TODO: this simplified implementation doesn't cover all unitary cases e.g. 3<=v<4
19✔
1952
        if r.lRange == nil || r.hRange == nil {
19✔
1953
                return false
×
1954
        }
×
1955

1956
        res, _ := r.lRange.val.Compare(r.hRange.val)
19✔
1957
        return res == 0 && r.lRange.inclusive && r.hRange.inclusive
19✔
1958
}
1959

1960
func (r *typedValueRange) refineWith(refiningRange *typedValueRange) error {
5✔
1961
        if r.lRange == nil {
6✔
1962
                r.lRange = refiningRange.lRange
1✔
1963
        } else if r.lRange != nil && refiningRange.lRange != nil {
6✔
1964
                maxRange, err := maxSemiRange(r.lRange, refiningRange.lRange)
1✔
1965
                if err != nil {
1✔
1966
                        return err
×
1967
                }
×
1968
                r.lRange = maxRange
1✔
1969
        }
1970

1971
        if r.hRange == nil {
8✔
1972
                r.hRange = refiningRange.hRange
3✔
1973
        } else if r.hRange != nil && refiningRange.hRange != nil {
7✔
1974
                minRange, err := minSemiRange(r.hRange, refiningRange.hRange)
2✔
1975
                if err != nil {
2✔
1976
                        return err
×
1977
                }
×
1978
                r.hRange = minRange
2✔
1979
        }
1980

1981
        return nil
5✔
1982
}
1983

1984
func (r *typedValueRange) extendWith(extendingRange *typedValueRange) error {
5✔
1985
        if r.lRange == nil || extendingRange.lRange == nil {
7✔
1986
                r.lRange = nil
2✔
1987
        } else {
5✔
1988
                minRange, err := minSemiRange(r.lRange, extendingRange.lRange)
3✔
1989
                if err != nil {
3✔
1990
                        return err
×
1991
                }
×
1992
                r.lRange = minRange
3✔
1993
        }
1994

1995
        if r.hRange == nil || extendingRange.hRange == nil {
8✔
1996
                r.hRange = nil
3✔
1997
        } else {
5✔
1998
                maxRange, err := maxSemiRange(r.hRange, extendingRange.hRange)
2✔
1999
                if err != nil {
2✔
2000
                        return err
×
2001
                }
×
2002
                r.hRange = maxRange
2✔
2003
        }
2004

2005
        return nil
5✔
2006
}
2007

2008
func maxSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
3✔
2009
        r, err := or1.val.Compare(or2.val)
3✔
2010
        if err != nil {
3✔
2011
                return nil, err
×
2012
        }
×
2013

2014
        maxVal := or1.val
3✔
2015
        if r < 0 {
5✔
2016
                maxVal = or2.val
2✔
2017
        }
2✔
2018

2019
        return &typedValueSemiRange{
3✔
2020
                val:       maxVal,
3✔
2021
                inclusive: or1.inclusive && or2.inclusive,
3✔
2022
        }, nil
3✔
2023
}
2024

2025
func minSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
5✔
2026
        r, err := or1.val.Compare(or2.val)
5✔
2027
        if err != nil {
5✔
2028
                return nil, err
×
2029
        }
×
2030

2031
        minVal := or1.val
5✔
2032
        if r > 0 {
9✔
2033
                minVal = or2.val
4✔
2034
        }
4✔
2035

2036
        return &typedValueSemiRange{
5✔
2037
                val:       minVal,
5✔
2038
                inclusive: or1.inclusive || or2.inclusive,
5✔
2039
        }, nil
5✔
2040
}
2041

2042
type TypedValue interface {
2043
        ValueExp
2044
        Type() SQLValueType
2045
        RawValue() interface{}
2046
        Compare(val TypedValue) (int, error)
2047
        IsNull() bool
2048
}
2049

2050
type Tuple []TypedValue
2051

2052
func (t Tuple) Compare(other Tuple) (int, int, error) {
288,871✔
2053
        if len(t) != len(other) {
288,871✔
2054
                return -1, -1, ErrNotComparableValues
×
2055
        }
×
2056

2057
        for i := range t {
612,823✔
2058
                res, err := t[i].Compare(other[i])
323,952✔
2059
                if err != nil || res != 0 {
598,701✔
2060
                        return res, i, err
274,749✔
2061
                }
274,749✔
2062
        }
2063
        return 0, -1, nil
14,122✔
2064
}
2065

2066
func NewNull(t SQLValueType) *NullValue {
403✔
2067
        return &NullValue{t: t}
403✔
2068
}
403✔
2069

2070
type NullValue struct {
2071
        t SQLValueType
2072
}
2073

2074
func (n *NullValue) Type() SQLValueType {
128✔
2075
        return n.t
128✔
2076
}
128✔
2077

2078
func (n *NullValue) RawValue() interface{} {
406✔
2079
        return nil
406✔
2080
}
406✔
2081

2082
func (n *NullValue) IsNull() bool {
387✔
2083
        return true
387✔
2084
}
387✔
2085

2086
func (n *NullValue) String() string {
4✔
2087
        return "NULL"
4✔
2088
}
4✔
2089

2090
func (n *NullValue) Compare(val TypedValue) (int, error) {
111✔
2091
        if n.t != AnyType && val.Type() != AnyType && n.t != val.Type() {
112✔
2092
                return 0, ErrNotComparableValues
1✔
2093
        }
1✔
2094

2095
        if val.RawValue() == nil {
156✔
2096
                return 0, nil
46✔
2097
        }
46✔
2098
        return -1, nil
64✔
2099
}
2100

2101
func (v *NullValue) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
10✔
2102
        return v.t, nil
10✔
2103
}
10✔
2104

2105
func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
12✔
2106
        if v.t == t {
19✔
2107
                return nil
7✔
2108
        }
7✔
2109

2110
        if v.t != AnyType {
6✔
2111
                return ErrInvalidTypes
1✔
2112
        }
1✔
2113

2114
        v.t = t
4✔
2115

4✔
2116
        return nil
4✔
2117
}
2118

2119
func (v *NullValue) selectors() []Selector {
17✔
2120
        return nil
17✔
2121
}
17✔
2122

2123
func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) {
410✔
2124
        return v, nil
410✔
2125
}
410✔
2126

2127
func (v *NullValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
373✔
2128
        return v, nil
373✔
2129
}
373✔
2130

2131
func (v *NullValue) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
2132
        return v
10✔
2133
}
10✔
2134

2135
func (v *NullValue) isConstant() bool {
12✔
2136
        return true
12✔
2137
}
12✔
2138

2139
func (v *NullValue) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2140
        return nil
1✔
2141
}
1✔
2142

2143
type Integer struct {
2144
        val int64
2145
}
2146

2147
func NewInteger(val int64) *Integer {
315✔
2148
        return &Integer{val: val}
315✔
2149
}
315✔
2150

2151
func (v *Integer) Type() SQLValueType {
413,933✔
2152
        return IntegerType
413,933✔
2153
}
413,933✔
2154

2155
func (v *Integer) IsNull() bool {
154,831✔
2156
        return false
154,831✔
2157
}
154,831✔
2158

2159
func (v *Integer) String() string {
54✔
2160
        return strconv.FormatInt(v.val, 10)
54✔
2161
}
54✔
2162

2163
func (v *Integer) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
99✔
2164
        return IntegerType, nil
99✔
2165
}
99✔
2166

2167
func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
64✔
2168
        if t != IntegerType && t != JSONType {
68✔
2169
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
4✔
2170
        }
4✔
2171
        return nil
60✔
2172
}
2173

2174
func (v *Integer) selectors() []Selector {
54✔
2175
        return nil
54✔
2176
}
54✔
2177

2178
func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) {
15,659✔
2179
        return v, nil
15,659✔
2180
}
15,659✔
2181

2182
func (v *Integer) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
17,861✔
2183
        return v, nil
17,861✔
2184
}
17,861✔
2185

2186
func (v *Integer) reduceSelectors(row *Row, implicitTable string) ValueExp {
9✔
2187
        return v
9✔
2188
}
9✔
2189

2190
func (v *Integer) isConstant() bool {
125✔
2191
        return true
125✔
2192
}
125✔
2193

2194
func (v *Integer) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2195
        return nil
1✔
2196
}
1✔
2197

2198
func (v *Integer) RawValue() interface{} {
233,079✔
2199
        return v.val
233,079✔
2200
}
233,079✔
2201

2202
func (v *Integer) Compare(val TypedValue) (int, error) {
127,766✔
2203
        if val.IsNull() {
127,809✔
2204
                return 1, nil
43✔
2205
        }
43✔
2206

2207
        if val.Type() == JSONType {
127,724✔
2208
                res, err := val.Compare(v)
1✔
2209
                return -res, err
1✔
2210
        }
1✔
2211

2212
        if val.Type() == Float64Type {
127,722✔
2213
                r, err := val.Compare(v)
×
2214
                return r * -1, err
×
2215
        }
×
2216

2217
        if val.Type() != IntegerType {
127,729✔
2218
                return 0, ErrNotComparableValues
7✔
2219
        }
7✔
2220

2221
        rval := val.RawValue().(int64)
127,715✔
2222

127,715✔
2223
        if v.val == rval {
154,355✔
2224
                return 0, nil
26,640✔
2225
        }
26,640✔
2226

2227
        if v.val > rval {
146,873✔
2228
                return 1, nil
45,798✔
2229
        }
45,798✔
2230

2231
        return -1, nil
55,277✔
2232
}
2233

2234
type Timestamp struct {
2235
        val time.Time
2236
}
2237

2238
func (v *Timestamp) Type() SQLValueType {
55,297✔
2239
        return TimestampType
55,297✔
2240
}
55,297✔
2241

2242
func (v *Timestamp) IsNull() bool {
45,335✔
2243
        return false
45,335✔
2244
}
45,335✔
2245

2246
func (v *Timestamp) String() string {
1✔
2247
        return v.val.Format("2006-01-02 15:04:05.999999")
1✔
2248
}
1✔
2249

2250
func (v *Timestamp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2251
        return TimestampType, nil
1✔
2252
}
1✔
2253

2254
func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
14✔
2255
        if t != TimestampType {
15✔
2256
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2257
        }
1✔
2258

2259
        return nil
13✔
2260
}
2261

2262
func (v *Timestamp) selectors() []Selector {
1✔
2263
        return nil
1✔
2264
}
1✔
2265

2266
func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) {
1,396✔
2267
        return v, nil
1,396✔
2268
}
1,396✔
2269

2270
func (v *Timestamp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2,487✔
2271
        return v, nil
2,487✔
2272
}
2,487✔
2273

2274
func (v *Timestamp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2275
        return v
1✔
2276
}
1✔
2277

2278
func (v *Timestamp) isConstant() bool {
1✔
2279
        return true
1✔
2280
}
1✔
2281

2282
func (v *Timestamp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2283
        return nil
1✔
2284
}
1✔
2285

2286
func (v *Timestamp) RawValue() interface{} {
77,474✔
2287
        return v.val
77,474✔
2288
}
77,474✔
2289

2290
func (v *Timestamp) Compare(val TypedValue) (int, error) {
41,492✔
2291
        if val.IsNull() {
41,494✔
2292
                return 1, nil
2✔
2293
        }
2✔
2294

2295
        if val.Type() != TimestampType {
41,491✔
2296
                return 0, ErrNotComparableValues
1✔
2297
        }
1✔
2298

2299
        rval := val.RawValue().(time.Time)
41,489✔
2300

41,489✔
2301
        if v.val.Before(rval) {
59,205✔
2302
                return -1, nil
17,716✔
2303
        }
17,716✔
2304

2305
        if v.val.After(rval) {
47,353✔
2306
                return 1, nil
23,580✔
2307
        }
23,580✔
2308

2309
        return 0, nil
193✔
2310
}
2311

2312
type Varchar struct {
2313
        val string
2314
}
2315

2316
func NewVarchar(val string) *Varchar {
2,137✔
2317
        return &Varchar{val: val}
2,137✔
2318
}
2,137✔
2319

2320
func (v *Varchar) Type() SQLValueType {
181,345✔
2321
        return VarcharType
181,345✔
2322
}
181,345✔
2323

2324
func (v *Varchar) IsNull() bool {
90,236✔
2325
        return false
90,236✔
2326
}
90,236✔
2327

2328
func (v *Varchar) String() string {
18✔
2329
        return fmt.Sprintf("'%s'", v.val)
18✔
2330
}
18✔
2331

2332
func (v *Varchar) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
62✔
2333
        return VarcharType, nil
62✔
2334
}
62✔
2335

2336
func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
145✔
2337
        if t != VarcharType && t != JSONType {
149✔
2338
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
4✔
2339
        }
4✔
2340
        return nil
141✔
2341
}
2342

2343
func (v *Varchar) selectors() []Selector {
43✔
2344
        return nil
43✔
2345
}
43✔
2346

2347
func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) {
5,201✔
2348
        return v, nil
5,201✔
2349
}
5,201✔
2350

2351
func (v *Varchar) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
6,117✔
2352
        return v, nil
6,117✔
2353
}
6,117✔
2354

2355
func (v *Varchar) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2356
        return v
×
2357
}
×
2358

2359
func (v *Varchar) isConstant() bool {
39✔
2360
        return true
39✔
2361
}
39✔
2362

2363
func (v *Varchar) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2364
        return nil
1✔
2365
}
1✔
2366

2367
func (v *Varchar) RawValue() interface{} {
124,161✔
2368
        return v.val
124,161✔
2369
}
124,161✔
2370

2371
func (v *Varchar) Compare(val TypedValue) (int, error) {
83,049✔
2372
        if val.IsNull() {
83,095✔
2373
                return 1, nil
46✔
2374
        }
46✔
2375

2376
        if val.Type() == JSONType {
84,004✔
2377
                res, err := val.Compare(v)
1,001✔
2378
                return -res, err
1,001✔
2379
        }
1,001✔
2380

2381
        if val.Type() != VarcharType {
82,003✔
2382
                return 0, ErrNotComparableValues
1✔
2383
        }
1✔
2384

2385
        rval := val.RawValue().(string)
82,001✔
2386

82,001✔
2387
        return bytes.Compare([]byte(v.val), []byte(rval)), nil
82,001✔
2388
}
2389

2390
type UUID struct {
2391
        val uuid.UUID
2392
}
2393

2394
func NewUUID(val uuid.UUID) *UUID {
1✔
2395
        return &UUID{val: val}
1✔
2396
}
1✔
2397

2398
func (v *UUID) Type() SQLValueType {
10✔
2399
        return UUIDType
10✔
2400
}
10✔
2401

2402
func (v *UUID) IsNull() bool {
26✔
2403
        return false
26✔
2404
}
26✔
2405

2406
func (v *UUID) String() string {
1✔
2407
        return v.val.String()
1✔
2408
}
1✔
2409

2410
func (v *UUID) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2411
        return UUIDType, nil
1✔
2412
}
1✔
2413

2414
func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
2415
        if t != UUIDType {
6✔
2416
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
2✔
2417
        }
2✔
2418

2419
        return nil
2✔
2420
}
2421

2422
func (v *UUID) selectors() []Selector {
1✔
2423
        return nil
1✔
2424
}
1✔
2425

2426
func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) {
6✔
2427
        return v, nil
6✔
2428
}
6✔
2429

2430
func (v *UUID) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5✔
2431
        return v, nil
5✔
2432
}
5✔
2433

2434
func (v *UUID) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2435
        return v
1✔
2436
}
1✔
2437

2438
func (v *UUID) isConstant() bool {
1✔
2439
        return true
1✔
2440
}
1✔
2441

2442
func (v *UUID) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2443
        return nil
1✔
2444
}
1✔
2445

2446
func (v *UUID) RawValue() interface{} {
41✔
2447
        return v.val
41✔
2448
}
41✔
2449

2450
func (v *UUID) Compare(val TypedValue) (int, error) {
5✔
2451
        if val.IsNull() {
7✔
2452
                return 1, nil
2✔
2453
        }
2✔
2454

2455
        if val.Type() != UUIDType {
4✔
2456
                return 0, ErrNotComparableValues
1✔
2457
        }
1✔
2458

2459
        rval := val.RawValue().(uuid.UUID)
2✔
2460

2✔
2461
        return bytes.Compare(v.val[:], rval[:]), nil
2✔
2462
}
2463

2464
type Bool struct {
2465
        val bool
2466
}
2467

2468
func NewBool(val bool) *Bool {
208✔
2469
        return &Bool{val: val}
208✔
2470
}
208✔
2471

2472
func (v *Bool) Type() SQLValueType {
1,962✔
2473
        return BooleanType
1,962✔
2474
}
1,962✔
2475

2476
func (v *Bool) IsNull() bool {
1,492✔
2477
        return false
1,492✔
2478
}
1,492✔
2479

2480
func (v *Bool) String() string {
41✔
2481
        return strconv.FormatBool(v.val)
41✔
2482
}
41✔
2483

2484
func (v *Bool) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
32✔
2485
        return BooleanType, nil
32✔
2486
}
32✔
2487

2488
func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
58✔
2489
        if t != BooleanType && t != JSONType {
63✔
2490
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
5✔
2491
        }
5✔
2492
        return nil
53✔
2493
}
2494

2495
func (v *Bool) selectors() []Selector {
5✔
2496
        return nil
5✔
2497
}
5✔
2498

2499
func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) {
648✔
2500
        return v, nil
648✔
2501
}
648✔
2502

2503
func (v *Bool) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
735✔
2504
        return v, nil
735✔
2505
}
735✔
2506

2507
func (v *Bool) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2508
        return v
×
2509
}
×
2510

2511
func (v *Bool) isConstant() bool {
3✔
2512
        return true
3✔
2513
}
3✔
2514

2515
func (v *Bool) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
2516
        return nil
8✔
2517
}
8✔
2518

2519
func (v *Bool) RawValue() interface{} {
1,731✔
2520
        return v.val
1,731✔
2521
}
1,731✔
2522

2523
func (v *Bool) Compare(val TypedValue) (int, error) {
570✔
2524
        if val.IsNull() {
600✔
2525
                return 1, nil
30✔
2526
        }
30✔
2527

2528
        if val.Type() == JSONType {
541✔
2529
                res, err := val.Compare(v)
1✔
2530
                return -res, err
1✔
2531
        }
1✔
2532

2533
        if val.Type() != BooleanType {
539✔
2534
                return 0, ErrNotComparableValues
×
2535
        }
×
2536

2537
        rval := val.RawValue().(bool)
539✔
2538

539✔
2539
        if v.val == rval {
871✔
2540
                return 0, nil
332✔
2541
        }
332✔
2542

2543
        if v.val {
213✔
2544
                return 1, nil
6✔
2545
        }
6✔
2546

2547
        return -1, nil
201✔
2548
}
2549

2550
type Blob struct {
2551
        val []byte
2552
}
2553

2554
func NewBlob(val []byte) *Blob {
286✔
2555
        return &Blob{val: val}
286✔
2556
}
286✔
2557

2558
func (v *Blob) Type() SQLValueType {
53✔
2559
        return BLOBType
53✔
2560
}
53✔
2561

2562
func (v *Blob) IsNull() bool {
2,312✔
2563
        return false
2,312✔
2564
}
2,312✔
2565

2566
func (v *Blob) String() string {
2✔
2567
        return hex.EncodeToString(v.val)
2✔
2568
}
2✔
2569

2570
func (v *Blob) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2571
        return BLOBType, nil
1✔
2572
}
1✔
2573

2574
func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
2575
        if t != BLOBType {
3✔
2576
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BLOBType, t)
1✔
2577
        }
1✔
2578

2579
        return nil
1✔
2580
}
2581

2582
func (v *Blob) selectors() []Selector {
1✔
2583
        return nil
1✔
2584
}
1✔
2585

2586
func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) {
714✔
2587
        return v, nil
714✔
2588
}
714✔
2589

2590
func (v *Blob) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2591
        return v, nil
726✔
2592
}
726✔
2593

2594
func (v *Blob) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2595
        return v
×
2596
}
×
2597

2598
func (v *Blob) isConstant() bool {
7✔
2599
        return true
7✔
2600
}
7✔
2601

2602
func (v *Blob) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2603
        return nil
×
2604
}
×
2605

2606
func (v *Blob) RawValue() interface{} {
2,570✔
2607
        return v.val
2,570✔
2608
}
2,570✔
2609

2610
func (v *Blob) Compare(val TypedValue) (int, error) {
25✔
2611
        if val.IsNull() {
25✔
2612
                return 1, nil
×
2613
        }
×
2614

2615
        if val.Type() != BLOBType {
25✔
2616
                return 0, ErrNotComparableValues
×
2617
        }
×
2618

2619
        rval := val.RawValue().([]byte)
25✔
2620

25✔
2621
        return bytes.Compare(v.val, rval), nil
25✔
2622
}
2623

2624
type Float64 struct {
2625
        val float64
2626
}
2627

2628
func NewFloat64(val float64) *Float64 {
465✔
2629
        return &Float64{val: val}
465✔
2630
}
465✔
2631

2632
func (v *Float64) Type() SQLValueType {
292,430✔
2633
        return Float64Type
292,430✔
2634
}
292,430✔
2635

2636
func (v *Float64) IsNull() bool {
6,593✔
2637
        return false
6,593✔
2638
}
6,593✔
2639

2640
func (v *Float64) String() string {
3✔
2641
        return strconv.FormatFloat(float64(v.val), 'f', -1, 64)
3✔
2642
}
3✔
2643

2644
func (v *Float64) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
15✔
2645
        return Float64Type, nil
15✔
2646
}
15✔
2647

2648
func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2649
        if t != Float64Type && t != JSONType {
21✔
2650
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, Float64Type, t)
1✔
2651
        }
1✔
2652
        return nil
19✔
2653
}
2654

2655
func (v *Float64) selectors() []Selector {
3✔
2656
        return nil
3✔
2657
}
3✔
2658

2659
func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) {
2,280✔
2660
        return v, nil
2,280✔
2661
}
2,280✔
2662

2663
func (v *Float64) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
4,353✔
2664
        return v, nil
4,353✔
2665
}
4,353✔
2666

2667
func (v *Float64) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2668
        return v
1✔
2669
}
1✔
2670

2671
func (v *Float64) isConstant() bool {
5✔
2672
        return true
5✔
2673
}
5✔
2674

2675
func (v *Float64) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2676
        return nil
1✔
2677
}
1✔
2678

2679
func (v *Float64) RawValue() interface{} {
515,138✔
2680
        return v.val
515,138✔
2681
}
515,138✔
2682

2683
func (v *Float64) Compare(val TypedValue) (int, error) {
87,398✔
2684
        if val.Type() == JSONType {
87,399✔
2685
                res, err := val.Compare(v)
1✔
2686
                return -res, err
1✔
2687
        }
1✔
2688

2689
        convVal, err := mayApplyImplicitConversion(val.RawValue(), Float64Type)
87,397✔
2690
        if err != nil {
87,398✔
2691
                return 0, err
1✔
2692
        }
1✔
2693

2694
        if convVal == nil {
87,399✔
2695
                return 1, nil
3✔
2696
        }
3✔
2697

2698
        rval, ok := convVal.(float64)
87,393✔
2699
        if !ok {
87,393✔
2700
                return 0, ErrNotComparableValues
×
2701
        }
×
2702

2703
        if v.val == rval {
87,517✔
2704
                return 0, nil
124✔
2705
        }
124✔
2706

2707
        if v.val > rval {
126,162✔
2708
                return 1, nil
38,893✔
2709
        }
38,893✔
2710

2711
        return -1, nil
48,376✔
2712
}
2713

2714
type FnCall struct {
2715
        fn     string
2716
        params []ValueExp
2717
}
2718

2719
func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
24✔
2720
        fn, err := v.resolveFunc()
24✔
2721
        if err != nil {
25✔
2722
                return AnyType, nil
1✔
2723
        }
1✔
2724
        return fn.InferType(cols, params, implicitTable)
23✔
2725
}
2726

2727
func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2728
        fn, err := v.resolveFunc()
20✔
2729
        if err != nil {
21✔
2730
                return err
1✔
2731
        }
1✔
2732
        return fn.RequiresType(t, cols, params, implicitTable)
19✔
2733
}
2734

2735
func (v *FnCall) selectors() []Selector {
38✔
2736
        selectors := make([]Selector, 0)
38✔
2737
        for _, param := range v.params {
106✔
2738
                selectors = append(selectors, param.selectors()...)
68✔
2739
        }
68✔
2740
        return selectors
38✔
2741
}
2742

2743
func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) {
438✔
2744
        ps := make([]ValueExp, len(v.params))
438✔
2745
        for i, p := range v.params {
808✔
2746
                ps[i], err = p.substitute(params)
370✔
2747
                if err != nil {
370✔
2748
                        return nil, err
×
2749
                }
×
2750
        }
2751

2752
        return &FnCall{
438✔
2753
                fn:     v.fn,
438✔
2754
                params: ps,
438✔
2755
        }, nil
438✔
2756
}
2757

2758
func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
438✔
2759
        fn, err := v.resolveFunc()
438✔
2760
        if err != nil {
439✔
2761
                return nil, err
1✔
2762
        }
1✔
2763

2764
        fnInputs, err := v.reduceParams(tx, row, implicitTable)
437✔
2765
        if err != nil {
437✔
2766
                return nil, err
×
2767
        }
×
2768
        return fn.Apply(tx, fnInputs)
437✔
2769
}
2770

2771
func (v *FnCall) reduceParams(tx *SQLTx, row *Row, implicitTable string) ([]TypedValue, error) {
437✔
2772
        var values []TypedValue
437✔
2773
        if len(v.params) > 0 {
773✔
2774
                values = make([]TypedValue, len(v.params))
336✔
2775
                for i, p := range v.params {
706✔
2776
                        v, err := p.reduce(tx, row, implicitTable)
370✔
2777
                        if err != nil {
370✔
2778
                                return nil, err
×
2779
                        }
×
2780
                        values[i] = v
370✔
2781
                }
2782
        }
2783
        return values, nil
437✔
2784
}
2785

2786
func (v *FnCall) resolveFunc() (Function, error) {
482✔
2787
        fn, exists := builtinFunctions[strings.ToUpper(v.fn)]
482✔
2788
        if !exists {
485✔
2789
                return nil, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn)
3✔
2790
        }
3✔
2791
        return fn, nil
479✔
2792
}
2793

2794
func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2795
        return v
×
2796
}
×
2797

2798
func (v *FnCall) isConstant() bool {
13✔
2799
        return false
13✔
2800
}
13✔
2801

2802
func (v *FnCall) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2803
        return nil
×
2804
}
×
2805

2806
func (v *FnCall) String() string {
1✔
2807
        params := make([]string, len(v.params))
1✔
2808
        for i, p := range v.params {
4✔
2809
                params[i] = p.String()
3✔
2810
        }
3✔
2811
        return v.fn + "(" + strings.Join(params, ",") + ")"
1✔
2812
}
2813

2814
type Cast struct {
2815
        val ValueExp
2816
        t   SQLValueType
2817
}
2818

2819
func (c *Cast) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2820
        _, err := c.val.inferType(cols, params, implicitTable)
23✔
2821
        if err != nil {
24✔
2822
                return AnyType, err
1✔
2823
        }
1✔
2824

2825
        // val type may be restricted by compatible conversions, but multiple types may be compatible...
2826

2827
        return c.t, nil
22✔
2828
}
2829

2830
func (c *Cast) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
2831
        if c.t != t {
×
2832
                return fmt.Errorf("%w: can not use value cast to %s as %s", ErrInvalidTypes, c.t, t)
×
2833
        }
×
2834

2835
        return nil
×
2836
}
2837

2838
func (c *Cast) substitute(params map[string]interface{}) (ValueExp, error) {
281✔
2839
        val, err := c.val.substitute(params)
281✔
2840
        if err != nil {
281✔
2841
                return nil, err
×
2842
        }
×
2843
        c.val = val
281✔
2844
        return c, nil
281✔
2845
}
2846

2847
func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
269✔
2848
        val, err := c.val.reduce(tx, row, implicitTable)
269✔
2849
        if err != nil {
269✔
2850
                return nil, err
×
2851
        }
×
2852

2853
        conv, err := getConverter(val.Type(), c.t)
269✔
2854
        if conv == nil {
272✔
2855
                return nil, err
3✔
2856
        }
3✔
2857

2858
        return conv(val)
266✔
2859
}
2860

2861
func (v *Cast) selectors() []Selector {
6✔
2862
        return v.val.selectors()
6✔
2863
}
6✔
2864

2865
func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2866
        return &Cast{
×
2867
                val: c.val.reduceSelectors(row, implicitTable),
×
2868
                t:   c.t,
×
2869
        }
×
2870
}
×
2871

2872
func (c *Cast) isConstant() bool {
7✔
2873
        return c.val.isConstant()
7✔
2874
}
7✔
2875

2876
func (c *Cast) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2877
        return nil
×
2878
}
×
2879

2880
func (c *Cast) String() string {
1✔
2881
        return fmt.Sprintf("CAST (%s AS %s)", c.val.String(), c.t)
1✔
2882
}
1✔
2883

2884
type Param struct {
2885
        id  string
2886
        pos int
2887
}
2888

2889
func (v *Param) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
58✔
2890
        t, ok := params[v.id]
58✔
2891
        if !ok {
114✔
2892
                params[v.id] = AnyType
56✔
2893
                return AnyType, nil
56✔
2894
        }
56✔
2895

2896
        return t, nil
2✔
2897
}
2898

2899
func (v *Param) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
76✔
2900
        currT, ok := params[v.id]
76✔
2901
        if ok && currT != t && currT != AnyType {
80✔
2902
                return ErrInferredMultipleTypes
4✔
2903
        }
4✔
2904

2905
        params[v.id] = t
72✔
2906

72✔
2907
        return nil
72✔
2908
}
2909

2910
func (p *Param) substitute(params map[string]interface{}) (ValueExp, error) {
7,795✔
2911
        val, ok := params[p.id]
7,795✔
2912
        if !ok {
7,857✔
2913
                return nil, fmt.Errorf("%w(%s)", ErrMissingParameter, p.id)
62✔
2914
        }
62✔
2915

2916
        if val == nil {
7,790✔
2917
                return &NullValue{t: AnyType}, nil
57✔
2918
        }
57✔
2919

2920
        switch v := val.(type) {
7,676✔
2921
        case bool:
96✔
2922
                {
192✔
2923
                        return &Bool{val: v}, nil
96✔
2924
                }
96✔
2925
        case string:
1,983✔
2926
                {
3,966✔
2927
                        return &Varchar{val: v}, nil
1,983✔
2928
                }
1,983✔
2929
        case int:
2,142✔
2930
                {
4,284✔
2931
                        return &Integer{val: int64(v)}, nil
2,142✔
2932
                }
2,142✔
2933
        case uint:
×
2934
                {
×
2935
                        return &Integer{val: int64(v)}, nil
×
2936
                }
×
2937
        case uint64:
34✔
2938
                {
68✔
2939
                        return &Integer{val: int64(v)}, nil
34✔
2940
                }
34✔
2941
        case int64:
227✔
2942
                {
454✔
2943
                        return &Integer{val: v}, nil
227✔
2944
                }
227✔
2945
        case []byte:
14✔
2946
                {
28✔
2947
                        return &Blob{val: v}, nil
14✔
2948
                }
14✔
2949
        case time.Time:
1,092✔
2950
                {
2,184✔
2951
                        return &Timestamp{val: v.Truncate(time.Microsecond).UTC()}, nil
1,092✔
2952
                }
1,092✔
2953
        case float64:
2,087✔
2954
                {
4,174✔
2955
                        return &Float64{val: v}, nil
2,087✔
2956
                }
2,087✔
2957
        }
2958
        return nil, ErrUnsupportedParameter
1✔
2959
}
2960

2961
func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
×
2962
        return nil, ErrUnexpected
×
2963
}
×
2964

2965
func (p *Param) selectors() []Selector {
4✔
2966
        return nil
4✔
2967
}
4✔
2968

2969
func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2970
        return p
×
2971
}
×
2972

2973
func (p *Param) isConstant() bool {
130✔
2974
        return true
130✔
2975
}
130✔
2976

2977
func (v *Param) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
6✔
2978
        return nil
6✔
2979
}
6✔
2980

2981
func (v *Param) String() string {
2✔
2982
        return "@" + v.id
2✔
2983
}
2✔
2984

2985
type whenThenClause struct {
2986
        when, then ValueExp
2987
}
2988

2989
type CaseWhenExp struct {
2990
        exp      ValueExp
2991
        whenThen []whenThenClause
2992
        elseExp  ValueExp
2993
}
2994

2995
func (ce *CaseWhenExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
8✔
2996
        checkType := func(e ValueExp, expectedType SQLValueType) (string, error) {
18✔
2997
                t, err := e.inferType(cols, params, implicitTable)
10✔
2998
                if err != nil {
10✔
2999
                        return "", err
×
3000
                }
×
3001

3002
                if expectedType == AnyType {
15✔
3003
                        return t, nil
5✔
3004
                }
5✔
3005

3006
                if t != expectedType {
6✔
3007
                        if (t == Float64Type && expectedType == IntegerType) ||
1✔
3008
                                (t == IntegerType && expectedType == Float64Type) {
1✔
3009
                                return Float64Type, nil
×
3010
                        }
×
3011
                        return "", fmt.Errorf("%w: CASE types %s and %s cannot be matched", ErrInferredMultipleTypes, expectedType, t)
1✔
3012
                }
3013
                return t, nil
4✔
3014
        }
3015

3016
        searchType := BooleanType
8✔
3017
        inferredResType := AnyType
8✔
3018
        if ce.exp != nil {
11✔
3019
                t, err := ce.exp.inferType(cols, params, implicitTable)
3✔
3020
                if err != nil {
3✔
3021
                        return "", err
×
3022
                }
×
3023
                searchType = t
3✔
3024
        }
3025

3026
        for _, e := range ce.whenThen {
16✔
3027
                whenType, err := e.when.inferType(cols, params, implicitTable)
8✔
3028
                if err != nil {
8✔
3029
                        return "", err
×
3030
                }
×
3031

3032
                if whenType != searchType {
11✔
3033
                        return "", fmt.Errorf("%w: argument of CASE/WHEN must be of type %s, not type %s", ErrInvalidTypes, searchType, whenType)
3✔
3034
                }
3✔
3035

3036
                t, err := checkType(e.then, inferredResType)
5✔
3037
                if err != nil {
5✔
3038
                        return "", err
×
3039
                }
×
3040
                inferredResType = t
5✔
3041
        }
3042

3043
        if ce.elseExp != nil {
10✔
3044
                return checkType(ce.elseExp, inferredResType)
5✔
3045
        }
5✔
3046
        return inferredResType, nil
×
3047
}
3048

3049
func (ce *CaseWhenExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
3050
        inferredType, err := ce.inferType(cols, params, implicitTable)
2✔
3051
        if err != nil {
3✔
3052
                return err
1✔
3053
        }
1✔
3054

3055
        if inferredType != t {
1✔
3056
                return fmt.Errorf("%w: expected type %s but %s found instead", ErrInvalidTypes, t, inferredType)
×
3057
        }
×
3058
        return nil
1✔
3059
}
3060

3061
func (ce *CaseWhenExp) substitute(params map[string]interface{}) (ValueExp, error) {
504✔
3062
        var exp ValueExp
504✔
3063
        if ce.exp != nil {
605✔
3064
                e, err := ce.exp.substitute(params)
101✔
3065
                if err != nil {
101✔
3066
                        return nil, err
×
3067
                }
×
3068
                exp = e
101✔
3069
        }
3070

3071
        whenThen := make([]whenThenClause, len(ce.whenThen))
504✔
3072
        for i, wt := range ce.whenThen {
1,208✔
3073
                whenValue, err := wt.when.substitute(params)
704✔
3074
                if err != nil {
704✔
3075
                        return nil, err
×
3076
                }
×
3077
                whenThen[i].when = whenValue
704✔
3078

704✔
3079
                thenValue, err := wt.then.substitute(params)
704✔
3080
                if err != nil {
704✔
3081
                        return nil, err
×
3082
                }
×
3083
                whenThen[i].then = thenValue
704✔
3084
        }
3085

3086
        if ce.elseExp == nil {
506✔
3087
                return &CaseWhenExp{
2✔
3088
                        exp:      exp,
2✔
3089
                        whenThen: whenThen,
2✔
3090
                }, nil
2✔
3091
        }
2✔
3092

3093
        elseValue, err := ce.elseExp.substitute(params)
502✔
3094
        if err != nil {
502✔
3095
                return nil, err
×
3096
        }
×
3097
        return &CaseWhenExp{
502✔
3098
                exp:      exp,
502✔
3099
                whenThen: whenThen,
502✔
3100
                elseExp:  elseValue,
502✔
3101
        }, nil
502✔
3102
}
3103

3104
func (ce *CaseWhenExp) selectors() []Selector {
7✔
3105
        selectors := make([]Selector, 0)
7✔
3106
        if ce.exp != nil {
8✔
3107
                selectors = append(selectors, ce.exp.selectors()...)
1✔
3108
        }
1✔
3109

3110
        for _, wh := range ce.whenThen {
16✔
3111
                selectors = append(selectors, wh.when.selectors()...)
9✔
3112
                selectors = append(selectors, wh.then.selectors()...)
9✔
3113
        }
9✔
3114

3115
        if ce.elseExp == nil {
9✔
3116
                return selectors
2✔
3117
        }
2✔
3118
        return append(selectors, ce.elseExp.selectors()...)
5✔
3119
}
3120

3121
func (ce *CaseWhenExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
304✔
3122
        var searchValue TypedValue
304✔
3123
        if ce.exp != nil {
405✔
3124
                v, err := ce.exp.reduce(tx, row, implicitTable)
101✔
3125
                if err != nil {
101✔
3126
                        return nil, err
×
3127
                }
×
3128
                searchValue = v
101✔
3129
        } else {
203✔
3130
                searchValue = &Bool{val: true}
203✔
3131
        }
203✔
3132

3133
        for _, wt := range ce.whenThen {
728✔
3134
                v, err := wt.when.reduce(tx, row, implicitTable)
424✔
3135
                if err != nil {
424✔
3136
                        return nil, err
×
3137
                }
×
3138

3139
                if v.Type() != searchValue.Type() {
425✔
3140
                        return nil, fmt.Errorf("%w: argument of CASE/WHEN must be type %s, not type %s", ErrInvalidTypes, v.Type(), searchValue.Type())
1✔
3141
                }
1✔
3142

3143
                res, err := v.Compare(searchValue)
423✔
3144
                if err != nil {
423✔
3145
                        return nil, err
×
3146
                }
×
3147
                if res == 0 {
616✔
3148
                        return wt.then.reduce(tx, row, implicitTable)
193✔
3149
                }
193✔
3150
        }
3151

3152
        if ce.elseExp == nil {
111✔
3153
                return NewNull(AnyType), nil
1✔
3154
        }
1✔
3155
        return ce.elseExp.reduce(tx, row, implicitTable)
109✔
3156
}
3157

3158
func (ce *CaseWhenExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3159
        whenThen := make([]whenThenClause, len(ce.whenThen))
1✔
3160
        for i, wt := range ce.whenThen {
2✔
3161
                whenValue := wt.when.reduceSelectors(row, implicitTable)
1✔
3162
                whenThen[i].when = whenValue
1✔
3163

1✔
3164
                thenValue := wt.then.reduceSelectors(row, implicitTable)
1✔
3165
                whenThen[i].then = thenValue
1✔
3166
        }
1✔
3167

3168
        if ce.elseExp == nil {
1✔
3169
                return &CaseWhenExp{
×
3170
                        whenThen: whenThen,
×
3171
                }
×
3172
        }
×
3173

3174
        return &CaseWhenExp{
1✔
3175
                whenThen: whenThen,
1✔
3176
                elseExp:  ce.elseExp.reduceSelectors(row, implicitTable),
1✔
3177
        }
1✔
3178
}
3179

3180
func (ce *CaseWhenExp) isConstant() bool {
1✔
3181
        return false
1✔
3182
}
1✔
3183

3184
func (ce *CaseWhenExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
3185
        return nil
1✔
3186
}
1✔
3187

3188
func (ce *CaseWhenExp) String() string {
3✔
3189
        var sb strings.Builder
3✔
3190
        for _, wh := range ce.whenThen {
7✔
3191
                sb.WriteString(fmt.Sprintf("WHEN %s THEN %s ", wh.when.String(), wh.then.String()))
4✔
3192
        }
4✔
3193

3194
        if ce.elseExp != nil {
5✔
3195
                sb.WriteString("ELSE " + ce.elseExp.String() + " ")
2✔
3196
        }
2✔
3197
        return "CASE " + sb.String() + "END"
3✔
3198
}
3199

3200
// Comparison identifies a comparison operator.
type Comparison int

// Comparison operators, numbered from zero in declaration order.
const (
	// EqualTo is the "=" operator.
	EqualTo Comparison = iota
	// LowerThan is the "<" operator.
	LowerThan
	// LowerOrEqualTo is the "<=" operator.
	LowerOrEqualTo
	// GreaterThan is the ">" operator.
	GreaterThan
	// GreaterOrEqualTo is the ">=" operator.
	GreaterOrEqualTo
)
3209

3210
type DataSource interface {
3211
        SQLStmt
3212
        Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error)
3213
        Alias() string
3214
}
3215

3216
type TargetEntry struct {
3217
        Exp ValueExp
3218
        As  string
3219
}
3220

3221
type SelectStmt struct {
3222
        distinct  bool
3223
        targets   []TargetEntry
3224
        selectors []Selector
3225
        ds        DataSource
3226
        indexOn   []string
3227
        joins     []*JoinSpec
3228
        where     ValueExp
3229
        groupBy   []*ColSelector
3230
        having    ValueExp
3231
        orderBy   []*OrdExp
3232
        limit     ValueExp
3233
        offset    ValueExp
3234
        as        string
3235
}
3236

3237
func NewSelectStmt(
3238
        targets []TargetEntry,
3239
        ds DataSource,
3240
        where ValueExp,
3241
        orderBy []*OrdExp,
3242
        limit ValueExp,
3243
        offset ValueExp,
3244
) *SelectStmt {
71✔
3245
        return &SelectStmt{
71✔
3246
                targets: targets,
71✔
3247
                ds:      ds,
71✔
3248
                where:   where,
71✔
3249
                orderBy: orderBy,
71✔
3250
                limit:   limit,
71✔
3251
                offset:  offset,
71✔
3252
        }
71✔
3253
}
71✔
3254

3255
func (stmt *SelectStmt) readOnly() bool {
94✔
3256
        return true
94✔
3257
}
94✔
3258

3259
func (stmt *SelectStmt) requiredPrivileges() []SQLPrivilege {
96✔
3260
        return []SQLPrivilege{SQLPrivilegeSelect}
96✔
3261
}
96✔
3262

3263
func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
53✔
3264
        _, err := stmt.execAt(ctx, tx, nil)
53✔
3265
        if err != nil {
53✔
3266
                return err
×
3267
        }
×
3268

3269
        // TODO: (jeroiraz) may be optimized so to resolve the query statement just once
3270
        rowReader, err := stmt.Resolve(ctx, tx, nil, nil)
53✔
3271
        if err != nil {
54✔
3272
                return err
1✔
3273
        }
1✔
3274
        defer rowReader.Close()
52✔
3275

52✔
3276
        return rowReader.InferParameters(ctx, params)
52✔
3277
}
3278

3279
func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
638✔
3280
        if stmt.groupBy == nil && stmt.having != nil {
639✔
3281
                return nil, ErrHavingClauseRequiresGroupClause
1✔
3282
        }
1✔
3283

3284
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
722✔
3285
                for _, sel := range stmt.targetSelectors() {
227✔
3286
                        _, isAgg := sel.(*AggColSelector)
142✔
3287
                        if !isAgg && !stmt.groupByContains(sel) {
144✔
3288
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3289
                        }
2✔
3290
                }
3291
        }
3292

3293
        if len(stmt.orderBy) > 0 {
792✔
3294
                for _, col := range stmt.orderBy {
352✔
3295
                        for _, sel := range col.exp.selectors() {
380✔
3296
                                _, isAgg := sel.(*AggColSelector)
185✔
3297
                                if (isAgg && !stmt.selectorAppearsInTargets(sel)) || (!isAgg && len(stmt.groupBy) > 0 && !stmt.groupByContains(sel)) {
187✔
3298
                                        return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3299
                                }
2✔
3300
                        }
3301
                }
3302
        }
3303
        return tx, nil
633✔
3304
}
3305

3306
func (stmt *SelectStmt) targetSelectors() []Selector {
2,571✔
3307
        if stmt.selectors == nil {
3,544✔
3308
                stmt.selectors = stmt.extractSelectors()
973✔
3309
        }
973✔
3310
        return stmt.selectors
2,571✔
3311
}
3312

3313
func (stmt *SelectStmt) selectorAppearsInTargets(s Selector) bool {
4✔
3314
        encSel := EncodeSelector(s.resolve(stmt.Alias()))
4✔
3315

4✔
3316
        for _, sel := range stmt.targetSelectors() {
12✔
3317
                if EncodeSelector(sel.resolve(stmt.Alias())) == encSel {
11✔
3318
                        return true
3✔
3319
                }
3✔
3320
        }
3321
        return false
1✔
3322
}
3323

3324
func (stmt *SelectStmt) groupByContains(sel Selector) bool {
57✔
3325
        encSel := EncodeSelector(sel.resolve(stmt.Alias()))
57✔
3326

57✔
3327
        for _, colSel := range stmt.groupBy {
137✔
3328
                if EncodeSelector(colSel.resolve(stmt.Alias())) == encSel {
134✔
3329
                        return true
54✔
3330
                }
54✔
3331
        }
3332
        return false
3✔
3333
}
3334

3335
func (stmt *SelectStmt) extractGroupByCols() []*AggColSelector {
80✔
3336
        cols := make([]*AggColSelector, 0, len(stmt.targets))
80✔
3337

80✔
3338
        for _, t := range stmt.targets {
214✔
3339
                selectors := t.Exp.selectors()
134✔
3340
                for _, sel := range selectors {
268✔
3341
                        aggSel, isAgg := sel.(*AggColSelector)
134✔
3342
                        if isAgg {
244✔
3343
                                cols = append(cols, aggSel)
110✔
3344
                        }
110✔
3345
                }
3346
        }
3347
        return cols
80✔
3348
}
3349

3350
func (stmt *SelectStmt) extractSelectors() []Selector {
973✔
3351
        selectors := make([]Selector, 0, len(stmt.targets))
973✔
3352
        for _, t := range stmt.targets {
1,872✔
3353
                selectors = append(selectors, t.Exp.selectors()...)
899✔
3354
        }
899✔
3355
        return selectors
973✔
3356
}
3357

3358
func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
980✔
3359
        scanSpecs, err := stmt.genScanSpecs(tx, params)
980✔
3360
        if err != nil {
996✔
3361
                return nil, err
16✔
3362
        }
16✔
3363

3364
        rowReader, err := stmt.ds.Resolve(ctx, tx, params, scanSpecs)
964✔
3365
        if err != nil {
970✔
3366
                return nil, err
6✔
3367
        }
6✔
3368
        defer func() {
1,916✔
3369
                if err != nil {
965✔
3370
                        rowReader.Close()
7✔
3371
                }
7✔
3372
        }()
3373

3374
        if stmt.joins != nil {
973✔
3375
                var jointRowReader *jointRowReader
15✔
3376
                jointRowReader, err = newJointRowReader(rowReader, stmt.joins)
15✔
3377
                if err != nil {
16✔
3378
                        return nil, err
1✔
3379
                }
1✔
3380
                rowReader = jointRowReader
14✔
3381
        }
3382

3383
        if stmt.where != nil {
1,491✔
3384
                rowReader = newConditionalRowReader(rowReader, stmt.where)
534✔
3385
        }
534✔
3386

3387
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
1,037✔
3388
                if len(scanSpecs.groupBySortExps) > 0 {
92✔
3389
                        var sortRowReader *sortRowReader
12✔
3390
                        sortRowReader, err = newSortRowReader(rowReader, scanSpecs.groupBySortExps)
12✔
3391
                        if err != nil {
12✔
3392
                                return nil, err
×
3393
                        }
×
3394
                        rowReader = sortRowReader
12✔
3395
                }
3396

3397
                var groupedRowReader *groupedRowReader
80✔
3398
                groupedRowReader, err = newGroupedRowReader(rowReader, allAggregations(stmt.targets), stmt.extractGroupByCols(), stmt.groupBy)
80✔
3399
                if err != nil {
82✔
3400
                        return nil, err
2✔
3401
                }
2✔
3402
                rowReader = groupedRowReader
78✔
3403

78✔
3404
                if stmt.having != nil {
82✔
3405
                        rowReader = newConditionalRowReader(rowReader, stmt.having)
4✔
3406
                }
4✔
3407
        }
3408

3409
        if len(scanSpecs.orderBySortExps) > 0 {
1,004✔
3410
                var sortRowReader *sortRowReader
49✔
3411
                sortRowReader, err = newSortRowReader(rowReader, stmt.orderBy)
49✔
3412
                if err != nil {
50✔
3413
                        return nil, err
1✔
3414
                }
1✔
3415
                rowReader = sortRowReader
48✔
3416
        }
3417

3418
        projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.targets)
954✔
3419
        if err != nil {
957✔
3420
                return nil, err
3✔
3421
        }
3✔
3422
        rowReader = projectedRowReader
951✔
3423

951✔
3424
        if stmt.distinct {
958✔
3425
                var distinctRowReader *distinctRowReader
7✔
3426
                distinctRowReader, err = newDistinctRowReader(ctx, rowReader)
7✔
3427
                if err != nil {
7✔
3428
                        return nil, err
×
3429
                }
×
3430
                rowReader = distinctRowReader
7✔
3431
        }
3432

3433
        if stmt.offset != nil {
1,001✔
3434
                var offset int
50✔
3435
                offset, err = evalExpAsInt(tx, stmt.offset, params)
50✔
3436
                if err != nil {
50✔
3437
                        return nil, fmt.Errorf("%w: invalid offset", err)
×
3438
                }
×
3439

3440
                rowReader = newOffsetRowReader(rowReader, offset)
50✔
3441
        }
3442

3443
        if stmt.limit != nil {
1,048✔
3444
                var limit int
97✔
3445
                limit, err = evalExpAsInt(tx, stmt.limit, params)
97✔
3446
                if err != nil {
97✔
3447
                        return nil, fmt.Errorf("%w: invalid limit", err)
×
3448
                }
×
3449

3450
                if limit < 0 {
97✔
3451
                        return nil, fmt.Errorf("%w: invalid limit", ErrIllegalArguments)
×
3452
                }
×
3453

3454
                if limit > 0 {
140✔
3455
                        rowReader = newLimitRowReader(rowReader, limit)
43✔
3456
                }
43✔
3457
        }
3458
        return rowReader, nil
951✔
3459
}
3460

3461
func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
963✔
3462
        if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
969✔
3463
                if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {
8✔
3464
                        return orderByExps, nil
2✔
3465
                }
2✔
3466

3467
                if ordExpsHasPrefix(groupByCols, orderByExps, stmt.Alias()) {
5✔
3468
                        for i := range orderByExps {
2✔
3469
                                groupByCols[i].descOrder = orderByExps[i].descOrder
1✔
3470
                        }
1✔
3471
                        return groupByCols, nil
1✔
3472
                }
3473
        }
3474
        return groupByCols, orderByExps
960✔
3475
}
3476

3477
func ordExpsHasPrefix(cols, prefix []*OrdExp, table string) bool {
10✔
3478
        if len(prefix) > len(cols) {
12✔
3479
                return false
2✔
3480
        }
2✔
3481

3482
        for i := range prefix {
17✔
3483
                ls := prefix[i].AsSelector()
9✔
3484
                rs := cols[i].AsSelector()
9✔
3485

9✔
3486
                if ls == nil || rs == nil {
9✔
3487
                        return false
×
3488
                }
×
3489

3490
                if EncodeSelector(ls.resolve(table)) != EncodeSelector(rs.resolve(table)) {
14✔
3491
                        return false
5✔
3492
                }
5✔
3493
        }
3494
        return true
3✔
3495
}
3496

3497
func (stmt *SelectStmt) groupByOrdExps() []*OrdExp {
980✔
3498
        groupByCols := stmt.groupBy
980✔
3499

980✔
3500
        ordExps := make([]*OrdExp, 0, len(groupByCols))
980✔
3501
        for _, col := range groupByCols {
1,027✔
3502
                ordExps = append(ordExps, &OrdExp{exp: col})
47✔
3503
        }
47✔
3504
        return ordExps
980✔
3505
}
3506

3507
func ordExpsHaveAggregations(exps []*OrdExp) bool {
7✔
3508
        for _, e := range exps {
17✔
3509
                if _, isAgg := e.exp.(*AggColSelector); isAgg {
11✔
3510
                        return true
1✔
3511
                }
1✔
3512
        }
3513
        return false
6✔
3514
}
3515

3516
func (stmt *SelectStmt) containsAggregations() bool {
1,594✔
3517
        for _, sel := range stmt.targetSelectors() {
3,212✔
3518
                _, isAgg := sel.(*AggColSelector)
1,618✔
3519
                if isAgg {
1,781✔
3520
                        return true
163✔
3521
                }
163✔
3522
        }
3523
        return false
1,431✔
3524
}
3525

3526
func evalExpAsInt(tx *SQLTx, exp ValueExp, params map[string]interface{}) (int, error) {
147✔
3527
        offset, err := exp.substitute(params)
147✔
3528
        if err != nil {
147✔
3529
                return 0, err
×
3530
        }
×
3531

3532
        texp, err := offset.reduce(tx, nil, "")
147✔
3533
        if err != nil {
147✔
3534
                return 0, err
×
3535
        }
×
3536

3537
        convVal, err := mayApplyImplicitConversion(texp.RawValue(), IntegerType)
147✔
3538
        if err != nil {
147✔
3539
                return 0, ErrInvalidValue
×
3540
        }
×
3541

3542
        num, ok := convVal.(int64)
147✔
3543
        if !ok {
147✔
3544
                return 0, ErrInvalidValue
×
3545
        }
×
3546

3547
        if num > math.MaxInt32 {
147✔
3548
                return 0, ErrInvalidValue
×
3549
        }
×
3550

3551
        return int(num), nil
147✔
3552
}
3553

3554
func (stmt *SelectStmt) Alias() string {
167✔
3555
        if stmt.as == "" {
333✔
3556
                return stmt.ds.Alias()
166✔
3557
        }
166✔
3558

3559
        return stmt.as
1✔
3560
}
3561

3562
func (stmt *SelectStmt) hasTxMetadata() bool {
888✔
3563
        for _, sel := range stmt.targetSelectors() {
1,715✔
3564
                switch s := sel.(type) {
827✔
3565
                case *ColSelector:
713✔
3566
                        if s.col == txMetadataCol {
714✔
3567
                                return true
1✔
3568
                        }
1✔
3569
                case *JSONSelector:
21✔
3570
                        if s.ColSelector.col == txMetadataCol {
24✔
3571
                                return true
3✔
3572
                        }
3✔
3573
                }
3574
        }
3575
        return false
884✔
3576
}
3577

3578
func (stmt *SelectStmt) genScanSpecs(tx *SQLTx, params map[string]interface{}) (*ScanSpecs, error) {
980✔
3579
        groupByCols, orderByCols := stmt.groupByOrdExps(), stmt.orderBy
980✔
3580

980✔
3581
        tableRef, isTableRef := stmt.ds.(*tableRef)
980✔
3582
        if !isTableRef {
1,055✔
3583
                groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)
75✔
3584

75✔
3585
                return &ScanSpecs{
75✔
3586
                        groupBySortExps: groupByCols,
75✔
3587
                        orderBySortExps: orderByCols,
75✔
3588
                }, nil
75✔
3589
        }
75✔
3590

3591
        table, err := tableRef.referencedTable(tx)
905✔
3592
        if err != nil {
920✔
3593
                if tx.engine.tableResolveFor(tableRef.table) != nil {
16✔
3594
                        return &ScanSpecs{
1✔
3595
                                groupBySortExps: groupByCols,
1✔
3596
                                orderBySortExps: orderByCols,
1✔
3597
                        }, nil
1✔
3598
                }
1✔
3599
                return nil, err
14✔
3600
        }
3601

3602
        rangesByColID := make(map[uint32]*typedValueRange)
890✔
3603
        if stmt.where != nil {
1,413✔
3604
                err = stmt.where.selectorRanges(table, tableRef.Alias(), params, rangesByColID)
523✔
3605
                if err != nil {
525✔
3606
                        return nil, err
2✔
3607
                }
2✔
3608
        }
3609

3610
        preferredIndex, err := stmt.getPreferredIndex(table)
888✔
3611
        if err != nil {
888✔
3612
                return nil, err
×
3613
        }
×
3614

3615
        var sortingIndex *Index
888✔
3616
        if preferredIndex == nil {
1,746✔
3617
                sortingIndex = stmt.selectSortingIndex(groupByCols, orderByCols, table, rangesByColID)
858✔
3618
        } else {
888✔
3619
                sortingIndex = preferredIndex
30✔
3620
        }
30✔
3621

3622
        if sortingIndex == nil {
1,653✔
3623
                sortingIndex = table.primaryIndex
765✔
3624
        }
765✔
3625

3626
        if tableRef.history && !sortingIndex.IsPrimary() {
888✔
3627
                return nil, fmt.Errorf("%w: historical queries are supported over primary index", ErrIllegalArguments)
×
3628
        }
×
3629

3630
        if tableRef.diff && !sortingIndex.IsPrimary() {
888✔
3631
                return nil, fmt.Errorf("%w: diff queries are supported over primary index", ErrIllegalArguments)
×
3632
        }
×
3633

3634
        var descOrder bool
888✔
3635
        if len(groupByCols) > 0 && sortingIndex.coversOrdCols(groupByCols, rangesByColID) {
905✔
3636
                groupByCols = nil
17✔
3637
        }
17✔
3638

3639
        if len(groupByCols) == 0 && len(orderByCols) > 0 && sortingIndex.coversOrdCols(orderByCols, rangesByColID) {
990✔
3640
                descOrder = orderByCols[0].descOrder
102✔
3641
                orderByCols = nil
102✔
3642
        }
102✔
3643

3644
        groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)
888✔
3645

888✔
3646
        return &ScanSpecs{
888✔
3647
                Index:             sortingIndex,
888✔
3648
                rangesByColID:     rangesByColID,
888✔
3649
                IncludeHistory:    tableRef.history,
888✔
3650
                IncludeDiff:       tableRef.diff,
888✔
3651
                IncludeTxMetadata: stmt.hasTxMetadata(),
888✔
3652
                DescOrder:         descOrder,
888✔
3653
                groupBySortExps:   groupByCols,
888✔
3654
                orderBySortExps:   orderByCols,
888✔
3655
        }, nil
888✔
3656
}
3657

3658
func (stmt *SelectStmt) selectSortingIndex(groupByCols, orderByCols []*OrdExp, table *Table, rangesByColId map[uint32]*typedValueRange) *Index {
858✔
3659
        sortCols := groupByCols
858✔
3660
        if len(sortCols) == 0 {
1,690✔
3661
                sortCols = orderByCols
832✔
3662
        }
832✔
3663

3664
        if len(sortCols) == 0 {
1,577✔
3665
                return nil
719✔
3666
        }
719✔
3667

3668
        for _, idx := range table.indexes {
364✔
3669
                if idx.coversOrdCols(sortCols, rangesByColId) {
318✔
3670
                        return idx
93✔
3671
                }
93✔
3672
        }
3673
        return nil
46✔
3674
}
3675

3676
func (stmt *SelectStmt) getPreferredIndex(table *Table) (*Index, error) {
888✔
3677
        if len(stmt.indexOn) == 0 {
1,746✔
3678
                return nil, nil
858✔
3679
        }
858✔
3680

3681
        cols := make([]*Column, len(stmt.indexOn))
30✔
3682
        for i, colName := range stmt.indexOn {
80✔
3683
                col, err := table.GetColumnByName(colName)
50✔
3684
                if err != nil {
50✔
3685
                        return nil, err
×
3686
                }
×
3687

3688
                cols[i] = col
50✔
3689
        }
3690
        return table.GetIndexByName(indexName(table.name, cols))
30✔
3691
}
3692

3693
type UnionStmt struct {
3694
        distinct    bool
3695
        left, right DataSource
3696
}
3697

3698
func (stmt *UnionStmt) readOnly() bool {
1✔
3699
        return true
1✔
3700
}
1✔
3701

3702
func (stmt *UnionStmt) requiredPrivileges() []SQLPrivilege {
1✔
3703
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3704
}
1✔
3705

3706
func (stmt *UnionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3707
        err := stmt.left.inferParameters(ctx, tx, params)
1✔
3708
        if err != nil {
1✔
3709
                return err
×
3710
        }
×
3711
        return stmt.right.inferParameters(ctx, tx, params)
1✔
3712
}
3713

3714
func (stmt *UnionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
3715
        _, err := stmt.left.execAt(ctx, tx, params)
9✔
3716
        if err != nil {
9✔
3717
                return tx, err
×
3718
        }
×
3719

3720
        return stmt.right.execAt(ctx, tx, params)
9✔
3721
}
3722

3723
func (stmt *UnionStmt) resolveUnionAll(ctx context.Context, tx *SQLTx, params map[string]interface{}) (ret RowReader, err error) {
11✔
3724
        leftRowReader, err := stmt.left.Resolve(ctx, tx, params, nil)
11✔
3725
        if err != nil {
12✔
3726
                return nil, err
1✔
3727
        }
1✔
3728
        defer func() {
20✔
3729
                if err != nil {
14✔
3730
                        leftRowReader.Close()
4✔
3731
                }
4✔
3732
        }()
3733

3734
        rightRowReader, err := stmt.right.Resolve(ctx, tx, params, nil)
10✔
3735
        if err != nil {
11✔
3736
                return nil, err
1✔
3737
        }
1✔
3738
        defer func() {
18✔
3739
                if err != nil {
12✔
3740
                        rightRowReader.Close()
3✔
3741
                }
3✔
3742
        }()
3743

3744
        rowReader, err := newUnionRowReader(ctx, []RowReader{leftRowReader, rightRowReader})
9✔
3745
        if err != nil {
12✔
3746
                return nil, err
3✔
3747
        }
3✔
3748

3749
        return rowReader, nil
6✔
3750
}
3751

3752
func (stmt *UnionStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
11✔
3753
        rowReader, err := stmt.resolveUnionAll(ctx, tx, params)
11✔
3754
        if err != nil {
16✔
3755
                return nil, err
5✔
3756
        }
5✔
3757
        defer func() {
12✔
3758
                if err != nil {
7✔
3759
                        rowReader.Close()
1✔
3760
                }
1✔
3761
        }()
3762

3763
        if stmt.distinct {
11✔
3764
                distinctReader, err := newDistinctRowReader(ctx, rowReader)
5✔
3765
                if err != nil {
6✔
3766
                        return nil, err
1✔
3767
                }
1✔
3768
                rowReader = distinctReader
4✔
3769
        }
3770

3771
        return rowReader, nil
5✔
3772
}
3773

3774
func (stmt *UnionStmt) Alias() string {
×
3775
        return ""
×
3776
}
×
3777

3778
func NewTableRef(table string, as string) *tableRef {
179✔
3779
        return &tableRef{
179✔
3780
                table: table,
179✔
3781
                as:    as,
179✔
3782
        }
179✔
3783
}
179✔
3784

3785
type tableRef struct {
3786
        table   string
3787
        history bool
3788
        diff    bool
3789
        period  period
3790
        as      string
3791
}
3792

3793
func (ref *tableRef) readOnly() bool {
1✔
3794
        return true
1✔
3795
}
1✔
3796

3797
func (ref *tableRef) requiredPrivileges() []SQLPrivilege {
1✔
3798
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3799
}
1✔
3800

3801
type period struct {
3802
        start *openPeriod
3803
        end   *openPeriod
3804
}
3805

3806
type openPeriod struct {
3807
        inclusive bool
3808
        instant   periodInstant
3809
}
3810

3811
type periodInstant struct {
3812
        exp         ValueExp
3813
        instantType instantType
3814
}
3815

3816
type instantType = int
3817

3818
const (
3819
        txInstant instantType = iota
3820
        timeInstant
3821
)
3822

3823
func (i periodInstant) resolve(tx *SQLTx, params map[string]interface{}, asc, inclusive bool) (uint64, error) {
89✔
3824
        exp, err := i.exp.substitute(params)
89✔
3825
        if err != nil {
89✔
3826
                return 0, err
×
3827
        }
×
3828

3829
        instantVal, err := exp.reduce(tx, nil, "")
89✔
3830
        if err != nil {
91✔
3831
                return 0, err
2✔
3832
        }
2✔
3833

3834
        if i.instantType == txInstant {
140✔
3835
                txID, ok := instantVal.RawValue().(int64)
53✔
3836
                if !ok {
53✔
3837
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %s given", ErrIllegalArguments, instantVal.Type())
×
3838
                }
×
3839

3840
                if txID <= 0 {
60✔
3841
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %d given", ErrIllegalArguments, txID)
7✔
3842
                }
7✔
3843

3844
                if inclusive {
74✔
3845
                        return uint64(txID), nil
28✔
3846
                }
28✔
3847

3848
                if asc {
29✔
3849
                        return uint64(txID + 1), nil
11✔
3850
                }
11✔
3851

3852
                if txID <= 1 {
8✔
3853
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be greater than 1, %d given", ErrIllegalArguments, txID)
1✔
3854
                }
1✔
3855

3856
                return uint64(txID - 1), nil
6✔
3857
        } else {
34✔
3858

34✔
3859
                var ts time.Time
34✔
3860

34✔
3861
                if instantVal.Type() == TimestampType {
67✔
3862
                        ts = instantVal.RawValue().(time.Time)
33✔
3863
                } else {
34✔
3864
                        conv, err := getConverter(instantVal.Type(), TimestampType)
1✔
3865
                        if err != nil {
1✔
3866
                                return 0, err
×
3867
                        }
×
3868

3869
                        tval, err := conv(instantVal)
1✔
3870
                        if err != nil {
1✔
3871
                                return 0, err
×
3872
                        }
×
3873

3874
                        ts = tval.RawValue().(time.Time)
1✔
3875
                }
3876

3877
                sts := ts
34✔
3878

34✔
3879
                if asc {
57✔
3880
                        if !inclusive {
34✔
3881
                                sts = sts.Add(1 * time.Second)
11✔
3882
                        }
11✔
3883

3884
                        txHdr, err := tx.engine.store.FirstTxSince(sts)
23✔
3885
                        if err != nil {
34✔
3886
                                return 0, err
11✔
3887
                        }
11✔
3888

3889
                        return txHdr.ID, nil
12✔
3890
                }
3891

3892
                if !inclusive {
11✔
3893
                        sts = sts.Add(-1 * time.Second)
×
3894
                }
×
3895

3896
                txHdr, err := tx.engine.store.LastTxUntil(sts)
11✔
3897
                if err != nil {
11✔
3898
                        return 0, err
×
3899
                }
×
3900

3901
                return txHdr.ID, nil
11✔
3902
        }
3903
}
3904

3905
func (stmt *tableRef) referencedTable(tx *SQLTx) (*Table, error) {
4,206✔
3906
        table, err := tx.catalog.GetTableByName(stmt.table)
4,206✔
3907
        if err != nil {
4,226✔
3908
                return nil, err
20✔
3909
        }
20✔
3910
        return table, nil
4,186✔
3911
}
3912

3913
func (stmt *tableRef) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3914
        return nil
1✔
3915
}
1✔
3916

3917
func (stmt *tableRef) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3918
        return tx, nil
×
3919
}
×
3920

3921
func (stmt *tableRef) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
918✔
3922
        if tx == nil {
918✔
3923
                return nil, ErrIllegalArguments
×
3924
        }
×
3925

3926
        table, err := stmt.referencedTable(tx)
918✔
3927
        if err == nil {
1,835✔
3928
                if stmt.diff {
925✔
3929
                        return newDiffRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
8✔
3930
                }
8✔
3931
                return newRawRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
909✔
3932
        }
3933

3934
        if resolver := tx.engine.tableResolveFor(stmt.table); resolver != nil {
2✔
3935
                return resolver.Resolve(ctx, tx, stmt.Alias())
1✔
3936
        }
1✔
3937
        return nil, err
×
3938
}
3939

3940
func (stmt *tableRef) Alias() string {
688✔
3941
        if stmt.as == "" {
1,216✔
3942
                return stmt.table
528✔
3943
        }
528✔
3944
        return stmt.as
160✔
3945
}
3946

3947
type valuesDataSource struct {
3948
        inferTypes bool
3949
        rows       []*RowSpec
3950
}
3951

3952
func NewValuesDataSource(rows []*RowSpec) *valuesDataSource {
120✔
3953
        return &valuesDataSource{
120✔
3954
                rows: rows,
120✔
3955
        }
120✔
3956
}
120✔
3957

3958
func (ds *valuesDataSource) readOnly() bool {
×
3959
        return true
×
3960
}
×
3961

3962
func (ds *valuesDataSource) requiredPrivileges() []SQLPrivilege {
97✔
3963
        return nil
97✔
3964
}
97✔
3965

3966
func (ds *valuesDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3967
        return tx, nil
×
3968
}
×
3969

3970
func (ds *valuesDataSource) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
3971
        return nil
×
3972
}
×
3973

3974
func (ds *valuesDataSource) Alias() string {
×
3975
        return ""
×
3976
}
×
3977

3978
func (ds *valuesDataSource) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
2,371✔
3979
        if tx == nil {
2,371✔
3980
                return nil, ErrIllegalArguments
×
3981
        }
×
3982

3983
        cols := make([]ColDescriptor, len(ds.rows[0].Values))
2,371✔
3984
        for i := range cols {
11,439✔
3985
                cols[i] = ColDescriptor{
9,068✔
3986
                        Type:   AnyType,
9,068✔
3987
                        Column: fmt.Sprintf("col%d", i),
9,068✔
3988
                }
9,068✔
3989
        }
9,068✔
3990

3991
        emptyColsDesc, emptyParams := map[string]ColDescriptor{}, map[string]string{}
2,371✔
3992

2,371✔
3993
        if ds.inferTypes {
2,384✔
3994
                for i := 0; i < len(cols); i++ {
56✔
3995
                        t := AnyType
43✔
3996
                        for j := 0; j < len(ds.rows); j++ {
154✔
3997
                                e, err := ds.rows[j].Values[i].substitute(params)
111✔
3998
                                if err != nil {
111✔
3999
                                        return nil, err
×
4000
                                }
×
4001

4002
                                it, err := e.inferType(emptyColsDesc, emptyParams, "")
111✔
4003
                                if err != nil {
111✔
4004
                                        return nil, err
×
4005
                                }
×
4006

4007
                                if t == AnyType {
154✔
4008
                                        t = it
43✔
4009
                                } else if t != it && it != AnyType {
113✔
4010
                                        return nil, fmt.Errorf("cannot match types %s and %s", t, it)
2✔
4011
                                }
2✔
4012
                        }
4013
                        cols[i].Type = t
41✔
4014
                }
4015
        }
4016

4017
        values := make([][]ValueExp, len(ds.rows))
2,369✔
4018
        for i, rowSpec := range ds.rows {
4,847✔
4019
                values[i] = rowSpec.Values
2,478✔
4020
        }
2,478✔
4021
        return NewValuesRowReader(tx, params, cols, ds.inferTypes, "values", values)
2,369✔
4022
}
4023

4024
type JoinSpec struct {
4025
        joinType JoinType
4026
        ds       DataSource
4027
        cond     ValueExp
4028
        indexOn  []string
4029
}
4030

4031
type OrdExp struct {
4032
        exp       ValueExp
4033
        descOrder bool
4034
}
4035

4036
func (oc *OrdExp) AsSelector() Selector {
714✔
4037
        sel, ok := oc.exp.(Selector)
714✔
4038
        if ok {
1,374✔
4039
                return sel
660✔
4040
        }
660✔
4041
        return nil
54✔
4042
}
4043

4044
func NewOrdCol(table string, col string, descOrder bool) *OrdExp {
1✔
4045
        return &OrdExp{
1✔
4046
                exp:       NewColSelector(table, col),
1✔
4047
                descOrder: descOrder,
1✔
4048
        }
1✔
4049
}
1✔
4050

4051
type Selector interface {
4052
        ValueExp
4053
        resolve(implicitTable string) (aggFn, table, col string)
4054
}
4055

4056
type ColSelector struct {
4057
        table string
4058
        col   string
4059
}
4060

4061
func NewColSelector(table, col string) *ColSelector {
126✔
4062
        return &ColSelector{
126✔
4063
                table: table,
126✔
4064
                col:   col,
126✔
4065
        }
126✔
4066
}
126✔
4067

4068
func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,187,729✔
4069
        table = implicitTable
1,187,729✔
4070
        if sel.table != "" {
1,575,201✔
4071
                table = sel.table
387,472✔
4072
        }
387,472✔
4073
        return "", table, sel.col
1,187,729✔
4074
}
4075

4076
func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
695✔
4077
        _, table, col := sel.resolve(implicitTable)
695✔
4078
        encSel := EncodeSelector("", table, col)
695✔
4079

695✔
4080
        desc, ok := cols[encSel]
695✔
4081
        if !ok {
698✔
4082
                return AnyType, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
3✔
4083
        }
3✔
4084
        return desc.Type, nil
692✔
4085
}
4086

4087
func (sel *ColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
4088
        _, table, col := sel.resolve(implicitTable)
15✔
4089
        encSel := EncodeSelector("", table, col)
15✔
4090

15✔
4091
        desc, ok := cols[encSel]
15✔
4092
        if !ok {
17✔
4093
                return fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
2✔
4094
        }
2✔
4095

4096
        if desc.Type != t {
16✔
4097
                return fmt.Errorf("%w: %v(%s) can not be interpreted as type %v", ErrInvalidTypes, desc.Type, encSel, t)
3✔
4098
        }
3✔
4099

4100
        return nil
10✔
4101
}
4102

4103
func (sel *ColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
205,810✔
4104
        return sel, nil
205,810✔
4105
}
205,810✔
4106

4107
func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
991,250✔
4108
        if row == nil {
991,251✔
4109
                return nil, fmt.Errorf("%w: no row to evaluate in current context", ErrInvalidValue)
1✔
4110
        }
1✔
4111

4112
        aggFn, table, col := sel.resolve(implicitTable)
991,249✔
4113

991,249✔
4114
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
991,249✔
4115
        if !ok {
991,256✔
4116
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
7✔
4117
        }
7✔
4118
        return v, nil
991,242✔
4119
}
4120

4121
func (sel *ColSelector) selectors() []Selector {
937✔
4122
        return []Selector{sel}
937✔
4123
}
937✔
4124

4125
func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
568✔
4126
        aggFn, table, col := sel.resolve(implicitTable)
568✔
4127

568✔
4128
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
568✔
4129
        if !ok {
846✔
4130
                return sel
278✔
4131
        }
278✔
4132

4133
        return v
290✔
4134
}
4135

4136
func (sel *ColSelector) isConstant() bool {
12✔
4137
        return false
12✔
4138
}
12✔
4139

4140
func (sel *ColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
11✔
4141
        return nil
11✔
4142
}
11✔
4143

4144
func (sel *ColSelector) String() string {
48✔
4145
        return sel.col
48✔
4146
}
48✔
4147

4148
type AggColSelector struct {
4149
        aggFn AggregateFn
4150
        table string
4151
        col   string
4152
}
4153

4154
func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector {
16✔
4155
        return &AggColSelector{
16✔
4156
                aggFn: aggFn,
16✔
4157
                table: table,
16✔
4158
                col:   col,
16✔
4159
        }
16✔
4160
}
16✔
4161

4162
func EncodeSelector(aggFn, table, col string) string {
1,859,912✔
4163
        return aggFn + "(" + table + "." + col + ")"
1,859,912✔
4164
}
1,859,912✔
4165

4166
func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,628✔
4167
        table = implicitTable
1,628✔
4168
        if sel.table != "" {
1,801✔
4169
                table = sel.table
173✔
4170
        }
173✔
4171
        return sel.aggFn, table, sel.col
1,628✔
4172
}
4173

4174
func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
36✔
4175
        if sel.aggFn == COUNT {
55✔
4176
                return IntegerType, nil
19✔
4177
        }
19✔
4178

4179
        colSelector := &ColSelector{table: sel.table, col: sel.col}
17✔
4180

17✔
4181
        if sel.aggFn == SUM || sel.aggFn == AVG {
24✔
4182
                t, err := colSelector.inferType(cols, params, implicitTable)
7✔
4183
                if err != nil {
7✔
4184
                        return AnyType, err
×
4185
                }
×
4186

4187
                if t != IntegerType && t != Float64Type {
7✔
4188
                        return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
×
4189

×
4190
                }
×
4191

4192
                return t, nil
7✔
4193
        }
4194

4195
        return colSelector.inferType(cols, params, implicitTable)
10✔
4196
}
4197

4198
func (sel *AggColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
8✔
4199
        if sel.aggFn == COUNT {
10✔
4200
                if t != IntegerType {
3✔
4201
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4202
                }
1✔
4203
                return nil
1✔
4204
        }
4205

4206
        colSelector := &ColSelector{table: sel.table, col: sel.col}
6✔
4207

6✔
4208
        if sel.aggFn == SUM || sel.aggFn == AVG {
10✔
4209
                if t != IntegerType && t != Float64Type {
5✔
4210
                        return fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
1✔
4211
                }
1✔
4212
        }
4213

4214
        return colSelector.requiresType(t, cols, params, implicitTable)
5✔
4215
}
4216

4217
func (sel *AggColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
412✔
4218
        return sel, nil
412✔
4219
}
412✔
4220

4221
func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
501✔
4222
        if row == nil {
502✔
4223
                return nil, fmt.Errorf("%w: no row to evaluate aggregation (%s) in current context", ErrInvalidValue, sel.aggFn)
1✔
4224
        }
1✔
4225

4226
        v, ok := row.ValuesBySelector[EncodeSelector(sel.resolve(implicitTable))]
500✔
4227
        if !ok {
501✔
4228
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, sel.col)
1✔
4229
        }
1✔
4230
        return v, nil
499✔
4231
}
4232

4233
func (sel *AggColSelector) selectors() []Selector {
232✔
4234
        return []Selector{sel}
232✔
4235
}
232✔
4236

4237
func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4238
        return sel
×
4239
}
×
4240

4241
func (sel *AggColSelector) isConstant() bool {
1✔
4242
        return false
1✔
4243
}
1✔
4244

4245
func (sel *AggColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
4246
        return nil
×
4247
}
×
4248

4249
func (sel *AggColSelector) String() string {
×
4250
        return sel.aggFn + "(" + sel.col + ")"
×
4251
}
×
4252

4253
type NumExp struct {
4254
        op          NumOperator
4255
        left, right ValueExp
4256
}
4257

4258
func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
17✔
4259
        // First step - check if we can infer the type of sub-expressions
17✔
4260
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
17✔
4261
        if err != nil {
17✔
4262
                return AnyType, err
×
4263
        }
×
4264
        if tleft != AnyType && tleft != IntegerType && tleft != Float64Type && tleft != JSONType {
17✔
4265
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft)
×
4266
        }
×
4267

4268
        tright, err := bexp.right.inferType(cols, params, implicitTable)
17✔
4269
        if err != nil {
17✔
4270
                return AnyType, err
×
4271
        }
×
4272
        if tright != AnyType && tright != IntegerType && tright != Float64Type && tright != JSONType {
19✔
4273
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright)
2✔
4274
        }
2✔
4275

4276
        if tleft == IntegerType && tright == IntegerType {
19✔
4277
                // Both sides are integer types - the result is also integer
4✔
4278
                return IntegerType, nil
4✔
4279
        }
4✔
4280

4281
        if tleft != AnyType && tright != AnyType {
20✔
4282
                // Both sides have concrete types but at least one of them is float
9✔
4283
                return Float64Type, nil
9✔
4284
        }
9✔
4285

4286
        // Both sides are ambiguous
4287
        return AnyType, nil
2✔
4288
}
4289

4290
func copyParams(params map[string]SQLValueType) map[string]SQLValueType {
11✔
4291
        ret := make(map[string]SQLValueType, len(params))
11✔
4292
        for k, v := range params {
15✔
4293
                ret[k] = v
4✔
4294
        }
4✔
4295
        return ret
11✔
4296
}
4297

4298
func restoreParams(params, restore map[string]SQLValueType) {
2✔
4299
        for k := range params {
2✔
4300
                delete(params, k)
×
4301
        }
×
4302
        for k, v := range restore {
2✔
4303
                params[k] = v
×
4304
        }
×
4305
}
4306

4307
func (bexp *NumExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4308
        if t != IntegerType && t != Float64Type {
8✔
4309
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4310
        }
1✔
4311

4312
        floatArgs := 2
6✔
4313
        paramsOrig := copyParams(params)
6✔
4314
        err := bexp.left.requiresType(t, cols, params, implicitTable)
6✔
4315
        if err != nil && t == Float64Type {
7✔
4316
                restoreParams(params, paramsOrig)
1✔
4317
                floatArgs--
1✔
4318
                err = bexp.left.requiresType(IntegerType, cols, params, implicitTable)
1✔
4319
        }
1✔
4320
        if err != nil {
7✔
4321
                return err
1✔
4322
        }
1✔
4323

4324
        paramsOrig = copyParams(params)
5✔
4325
        err = bexp.right.requiresType(t, cols, params, implicitTable)
5✔
4326
        if err != nil && t == Float64Type {
6✔
4327
                restoreParams(params, paramsOrig)
1✔
4328
                floatArgs--
1✔
4329
                err = bexp.right.requiresType(IntegerType, cols, params, implicitTable)
1✔
4330
        }
1✔
4331
        if err != nil {
7✔
4332
                return err
2✔
4333
        }
2✔
4334

4335
        if t == Float64Type && floatArgs == 0 {
3✔
4336
                // Currently this case requires explicit float cast
×
4337
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
×
4338
        }
×
4339

4340
        return nil
3✔
4341
}
4342

4343
func (bexp *NumExp) substitute(params map[string]interface{}) (ValueExp, error) {
187✔
4344
        rlexp, err := bexp.left.substitute(params)
187✔
4345
        if err != nil {
187✔
4346
                return nil, err
×
4347
        }
×
4348

4349
        rrexp, err := bexp.right.substitute(params)
187✔
4350
        if err != nil {
187✔
4351
                return nil, err
×
4352
        }
×
4353

4354
        bexp.left = rlexp
187✔
4355
        bexp.right = rrexp
187✔
4356

187✔
4357
        return bexp, nil
187✔
4358
}
4359

4360
func (bexp *NumExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
177,635✔
4361
        vl, err := bexp.left.reduce(tx, row, implicitTable)
177,635✔
4362
        if err != nil {
177,635✔
4363
                return nil, err
×
4364
        }
×
4365

4366
        vr, err := bexp.right.reduce(tx, row, implicitTable)
177,635✔
4367
        if err != nil {
177,635✔
4368
                return nil, err
×
4369
        }
×
4370

4371
        vl = unwrapJSON(vl)
177,635✔
4372
        vr = unwrapJSON(vr)
177,635✔
4373

177,635✔
4374
        return applyNumOperator(bexp.op, vl, vr)
177,635✔
4375
}
4376

4377
func unwrapJSON(v TypedValue) TypedValue {
355,270✔
4378
        if jsonVal, ok := v.(*JSON); ok {
355,370✔
4379
                if sv, isSimple := jsonVal.castToTypedValue(); isSimple {
200✔
4380
                        return sv
100✔
4381
                }
100✔
4382
        }
4383
        return v
355,170✔
4384
}
4385

4386
func (bexp *NumExp) selectors() []Selector {
13✔
4387
        return append(bexp.left.selectors(), bexp.right.selectors()...)
13✔
4388
}
13✔
4389

4390
func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4391
        return &NumExp{
1✔
4392
                op:    bexp.op,
1✔
4393
                left:  bexp.left.reduceSelectors(row, implicitTable),
1✔
4394
                right: bexp.right.reduceSelectors(row, implicitTable),
1✔
4395
        }
1✔
4396
}
1✔
4397

4398
func (bexp *NumExp) isConstant() bool {
5✔
4399
        return bexp.left.isConstant() && bexp.right.isConstant()
5✔
4400
}
5✔
4401

4402
func (bexp *NumExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
4✔
4403
        return nil
4✔
4404
}
4✔
4405

4406
func (bexp *NumExp) String() string {
18✔
4407
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), NumOperatorString(bexp.op), bexp.right.String())
18✔
4408
}
18✔
4409

4410
type NotBoolExp struct {
4411
        exp ValueExp
4412
}
4413

4414
func (bexp *NotBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4415
        err := bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
1✔
4416
        if err != nil {
1✔
4417
                return AnyType, err
×
4418
        }
×
4419

4420
        return BooleanType, nil
1✔
4421
}
4422

4423
func (bexp *NotBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
4424
        if t != BooleanType {
7✔
4425
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4426
        }
1✔
4427

4428
        return bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
5✔
4429
}
4430

4431
func (bexp *NotBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
22✔
4432
        rexp, err := bexp.exp.substitute(params)
22✔
4433
        if err != nil {
22✔
4434
                return nil, err
×
4435
        }
×
4436

4437
        bexp.exp = rexp
22✔
4438

22✔
4439
        return bexp, nil
22✔
4440
}
4441

4442
func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
22✔
4443
        v, err := bexp.exp.reduce(tx, row, implicitTable)
22✔
4444
        if err != nil {
22✔
4445
                return nil, err
×
4446
        }
×
4447

4448
        r, isBool := v.RawValue().(bool)
22✔
4449
        if !isBool {
22✔
4450
                return nil, ErrInvalidCondition
×
4451
        }
×
4452

4453
        return &Bool{val: !r}, nil
22✔
4454
}
4455

4456
func (bexp *NotBoolExp) selectors() []Selector {
×
4457
        return bexp.exp.selectors()
×
4458
}
×
4459

4460
func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4461
        return &NotBoolExp{
×
4462
                exp: bexp.exp.reduceSelectors(row, implicitTable),
×
4463
        }
×
4464
}
×
4465

4466
func (bexp *NotBoolExp) isConstant() bool {
1✔
4467
        return bexp.exp.isConstant()
1✔
4468
}
1✔
4469

4470
func (bexp *NotBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
4471
        return nil
7✔
4472
}
7✔
4473

4474
func (bexp *NotBoolExp) String() string {
12✔
4475
        return fmt.Sprintf("(NOT %s)", bexp.exp.String())
12✔
4476
}
12✔
4477

4478
type LikeBoolExp struct {
4479
        val     ValueExp
4480
        notLike bool
4481
        pattern ValueExp
4482
}
4483

4484
func NewLikeBoolExp(val ValueExp, notLike bool, pattern ValueExp) *LikeBoolExp {
4✔
4485
        return &LikeBoolExp{
4✔
4486
                val:     val,
4✔
4487
                notLike: notLike,
4✔
4488
                pattern: pattern,
4✔
4489
        }
4✔
4490
}
4✔
4491

4492
func (bexp *LikeBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
4493
        if bexp.val == nil || bexp.pattern == nil {
4✔
4494
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4495
        }
1✔
4496

4497
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
2✔
4498
        if err != nil {
3✔
4499
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4500
        }
1✔
4501

4502
        return BooleanType, nil
1✔
4503
}
4504

4505
func (bexp *LikeBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4506
        if bexp.val == nil || bexp.pattern == nil {
9✔
4507
                return fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
2✔
4508
        }
2✔
4509

4510
        if t != BooleanType {
7✔
4511
                return fmt.Errorf("error using the value of the LIKE operator as %s: %w", t, ErrInvalidTypes)
2✔
4512
        }
2✔
4513

4514
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
3✔
4515
        if err != nil {
4✔
4516
                return fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4517
        }
1✔
4518

4519
        return nil
2✔
4520
}
4521

4522
func (bexp *LikeBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
135✔
4523
        if bexp.val == nil || bexp.pattern == nil {
136✔
4524
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4525
        }
1✔
4526

4527
        val, err := bexp.val.substitute(params)
134✔
4528
        if err != nil {
134✔
4529
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4530
        }
×
4531

4532
        pattern, err := bexp.pattern.substitute(params)
134✔
4533
        if err != nil {
134✔
4534
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4535
        }
×
4536

4537
        return &LikeBoolExp{
134✔
4538
                val:     val,
134✔
4539
                notLike: bexp.notLike,
134✔
4540
                pattern: pattern,
134✔
4541
        }, nil
134✔
4542
}
4543

4544
func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
142✔
4545
        if bexp.val == nil || bexp.pattern == nil {
143✔
4546
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4547
        }
1✔
4548

4549
        rval, err := bexp.val.reduce(tx, row, implicitTable)
141✔
4550
        if err != nil {
141✔
4551
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4552
        }
×
4553

4554
        if rval.IsNull() {
142✔
4555
                return &Bool{val: bexp.notLike}, nil
1✔
4556
        }
1✔
4557

4558
        rvalStr, ok := rval.RawValue().(string)
140✔
4559
        if !ok {
141✔
4560
                return nil, fmt.Errorf("error in 'LIKE' clause: %w (expecting %s)", ErrInvalidTypes, VarcharType)
1✔
4561
        }
1✔
4562

4563
        rpattern, err := bexp.pattern.reduce(tx, row, implicitTable)
139✔
4564
        if err != nil {
139✔
4565
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4566
        }
×
4567

4568
        if rpattern.Type() != VarcharType {
139✔
4569
                return nil, fmt.Errorf("error evaluating 'LIKE' clause: %w", ErrInvalidTypes)
×
4570
        }
×
4571

4572
        matched, err := regexp.MatchString(rpattern.RawValue().(string), rvalStr)
139✔
4573
        if err != nil {
139✔
4574
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4575
        }
×
4576

4577
        return &Bool{val: matched != bexp.notLike}, nil
139✔
4578
}
4579

4580
func (bexp *LikeBoolExp) selectors() []Selector {
1✔
4581
        return bexp.val.selectors()
1✔
4582
}
1✔
4583

4584
func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4585
        return bexp
1✔
4586
}
1✔
4587

4588
func (bexp *LikeBoolExp) isConstant() bool {
2✔
4589
        return false
2✔
4590
}
2✔
4591

4592
func (bexp *LikeBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
4593
        return nil
8✔
4594
}
8✔
4595

4596
func (bexp *LikeBoolExp) String() string {
5✔
4597
        fmtStr := "(%s LIKE %s)"
5✔
4598
        if bexp.notLike {
6✔
4599
                fmtStr = "(%s NOT LIKE %s)"
1✔
4600
        }
1✔
4601
        return fmt.Sprintf(fmtStr, bexp.val.String(), bexp.pattern.String())
5✔
4602
}
4603

4604
type CmpBoolExp struct {
4605
        op          CmpOperator
4606
        left, right ValueExp
4607
}
4608

4609
func NewCmpBoolExp(op CmpOperator, left, right ValueExp) *CmpBoolExp {
67✔
4610
        return &CmpBoolExp{
67✔
4611
                op:    op,
67✔
4612
                left:  left,
67✔
4613
                right: right,
67✔
4614
        }
67✔
4615
}
67✔
4616

4617
func (bexp *CmpBoolExp) Left() ValueExp {
×
4618
        return bexp.left
×
4619
}
×
4620

4621
func (bexp *CmpBoolExp) Right() ValueExp {
×
4622
        return bexp.right
×
4623
}
×
4624

4625
func (bexp *CmpBoolExp) OP() CmpOperator {
×
4626
        return bexp.op
×
4627
}
×
4628

4629
func (bexp *CmpBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
63✔
4630
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
63✔
4631
        if err != nil {
63✔
4632
                return AnyType, err
×
4633
        }
×
4634

4635
        tright, err := bexp.right.inferType(cols, params, implicitTable)
63✔
4636
        if err != nil {
65✔
4637
                return AnyType, err
2✔
4638
        }
2✔
4639

4640
        // unification step
4641

4642
        if tleft == tright {
74✔
4643
                return BooleanType, nil
13✔
4644
        }
13✔
4645

4646
        _, ok := coerceTypes(tleft, tright)
48✔
4647
        if !ok {
52✔
4648
                return AnyType, fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, tleft, tright)
4✔
4649
        }
4✔
4650

4651
        if tleft == AnyType {
47✔
4652
                err = bexp.left.requiresType(tright, cols, params, implicitTable)
3✔
4653
                if err != nil {
3✔
4654
                        return AnyType, err
×
4655
                }
×
4656
        }
4657

4658
        if tright == AnyType {
84✔
4659
                err = bexp.right.requiresType(tleft, cols, params, implicitTable)
40✔
4660
                if err != nil {
41✔
4661
                        return AnyType, err
1✔
4662
                }
1✔
4663
        }
4664
        return BooleanType, nil
43✔
4665
}
4666

4667
func coerceTypes(t1, t2 SQLValueType) (SQLValueType, bool) {
48✔
4668
        switch {
48✔
4669
        case t1 == t2:
×
4670
                return t1, true
×
4671
        case t1 == AnyType:
3✔
4672
                return t2, true
3✔
4673
        case t2 == AnyType:
40✔
4674
                return t1, true
40✔
4675
        case (t1 == IntegerType && t2 == Float64Type) ||
4676
                (t1 == Float64Type && t2 == IntegerType):
1✔
4677
                return Float64Type, true
1✔
4678
        }
4679
        return "", false
4✔
4680
}
4681

4682
func (bexp *CmpBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
41✔
4683
        if t != BooleanType {
42✔
4684
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4685
        }
1✔
4686

4687
        _, err := bexp.inferType(cols, params, implicitTable)
40✔
4688
        return err
40✔
4689
}
4690

4691
func (bexp *CmpBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
14,330✔
4692
        rlexp, err := bexp.left.substitute(params)
14,330✔
4693
        if err != nil {
14,330✔
4694
                return nil, err
×
4695
        }
×
4696

4697
        rrexp, err := bexp.right.substitute(params)
14,330✔
4698
        if err != nil {
14,331✔
4699
                return nil, err
1✔
4700
        }
1✔
4701

4702
        bexp.left = rlexp
14,329✔
4703
        bexp.right = rrexp
14,329✔
4704

14,329✔
4705
        return bexp, nil
14,329✔
4706
}
4707

4708
func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
13,920✔
4709
        vl, err := bexp.left.reduce(tx, row, implicitTable)
13,920✔
4710
        if err != nil {
13,923✔
4711
                return nil, err
3✔
4712
        }
3✔
4713

4714
        vr, err := bexp.right.reduce(tx, row, implicitTable)
13,917✔
4715
        if err != nil {
13,919✔
4716
                return nil, err
2✔
4717
        }
2✔
4718

4719
        r, err := vl.Compare(vr)
13,915✔
4720
        if err != nil {
13,919✔
4721
                return nil, err
4✔
4722
        }
4✔
4723

4724
        return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil
13,911✔
4725
}
4726

4727
func (bexp *CmpBoolExp) selectors() []Selector {
12✔
4728
        return append(bexp.left.selectors(), bexp.right.selectors()...)
12✔
4729
}
12✔
4730

4731
func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
282✔
4732
        return &CmpBoolExp{
282✔
4733
                op:    bexp.op,
282✔
4734
                left:  bexp.left.reduceSelectors(row, implicitTable),
282✔
4735
                right: bexp.right.reduceSelectors(row, implicitTable),
282✔
4736
        }
282✔
4737
}
282✔
4738

4739
func (bexp *CmpBoolExp) isConstant() bool {
2✔
4740
        return bexp.left.isConstant() && bexp.right.isConstant()
2✔
4741
}
2✔
4742

4743
func (bexp *CmpBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
612✔
4744
        matchingFunc := func(_, right ValueExp) (*ColSelector, ValueExp, bool) {
1,433✔
4745
                s, isSel := bexp.left.(*ColSelector)
821✔
4746
                if isSel && s.col != revCol && bexp.right.isConstant() {
1,224✔
4747
                        return s, right, true
403✔
4748
                }
403✔
4749
                return nil, nil, false
418✔
4750
        }
4751

4752
        sel, c, ok := matchingFunc(bexp.left, bexp.right)
612✔
4753
        if !ok {
821✔
4754
                sel, c, ok = matchingFunc(bexp.right, bexp.left)
209✔
4755
        }
209✔
4756

4757
        if !ok {
821✔
4758
                return nil
209✔
4759
        }
209✔
4760

4761
        aggFn, t, col := sel.resolve(table.name)
403✔
4762
        if aggFn != "" || t != asTable {
417✔
4763
                return nil
14✔
4764
        }
14✔
4765

4766
        column, err := table.GetColumnByName(col)
389✔
4767
        if err != nil {
390✔
4768
                return err
1✔
4769
        }
1✔
4770

4771
        val, err := c.substitute(params)
388✔
4772
        if errors.Is(err, ErrMissingParameter) {
447✔
4773
                // TODO: not supported when parameters are not provided during query resolution
59✔
4774
                return nil
59✔
4775
        }
59✔
4776
        if err != nil {
329✔
4777
                return err
×
4778
        }
×
4779

4780
        rval, err := val.reduce(nil, nil, table.name)
329✔
4781
        if err != nil {
330✔
4782
                return err
1✔
4783
        }
1✔
4784

4785
        return updateRangeFor(column.id, rval, bexp.op, rangesByColID)
328✔
4786
}
4787

4788
func (bexp *CmpBoolExp) String() string {
20✔
4789
        opStr := CmpOperatorToString(bexp.op)
20✔
4790
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), opStr, bexp.right.String())
20✔
4791
}
20✔
4792

4793
type TimestampFieldType string
4794

4795
const (
4796
        TimestampFieldTypeYear   TimestampFieldType = "YEAR"
4797
        TimestampFieldTypeMonth  TimestampFieldType = "MONTH"
4798
        TimestampFieldTypeDay    TimestampFieldType = "DAY"
4799
        TimestampFieldTypeHour   TimestampFieldType = "HOUR"
4800
        TimestampFieldTypeMinute TimestampFieldType = "MINUTE"
4801
        TimestampFieldTypeSecond TimestampFieldType = "SECOND"
4802
)
4803

4804
type ExtractFromTimestampExp struct {
4805
        Field TimestampFieldType
4806
        Exp   ValueExp
4807
}
4808

4809
func (te *ExtractFromTimestampExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
4810
        inferredType, err := te.Exp.inferType(cols, params, implicitTable)
3✔
4811
        if err != nil {
3✔
4812
                return "", err
×
4813
        }
×
4814

4815
        if inferredType != TimestampType &&
3✔
4816
                inferredType != VarcharType &&
3✔
4817
                inferredType != AnyType {
3✔
4818
                return "", fmt.Errorf("timestamp expression must be of type %v or %v, but was: %v", TimestampType, VarcharType, inferredType)
×
4819
        }
×
4820
        return IntegerType, nil
3✔
4821
}
4822

4823
func (te *ExtractFromTimestampExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
4824
        if t != IntegerType && t != Float64Type {
4✔
4825
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
×
4826
        }
×
4827
        return te.Exp.requiresType(TimestampType, cols, params, implicitTable)
4✔
4828
}
4829

4830
func (te *ExtractFromTimestampExp) substitute(params map[string]interface{}) (ValueExp, error) {
18✔
4831
        exp, err := te.Exp.substitute(params)
18✔
4832
        if err != nil {
18✔
4833
                return nil, err
×
4834
        }
×
4835
        return &ExtractFromTimestampExp{
18✔
4836
                Field: te.Field,
18✔
4837
                Exp:   exp,
18✔
4838
        }, nil
18✔
4839
}
4840

4841
func (te *ExtractFromTimestampExp) selectors() []Selector {
12✔
4842
        return te.Exp.selectors()
12✔
4843
}
12✔
4844

4845
func (te *ExtractFromTimestampExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
18✔
4846
        v, err := te.Exp.reduce(tx, row, implicitTable)
18✔
4847
        if err != nil {
18✔
4848
                return nil, err
×
4849
        }
×
4850

4851
        if v.IsNull() {
18✔
4852
                return NewNull(IntegerType), nil
×
4853
        }
×
4854

4855
        if t := v.Type(); t != TimestampType && t != VarcharType {
18✔
4856
                return nil, fmt.Errorf("%w: expected type %v but found type %v", ErrInvalidTypes, TimestampType, t)
×
4857
        }
×
4858

4859
        if v.Type() == VarcharType {
22✔
4860
                converterFunc, err := getConverter(VarcharType, TimestampType)
4✔
4861
                if err != nil {
4✔
4862
                        return nil, err
×
4863
                }
×
4864
                casted, err := converterFunc(v)
4✔
4865
                if err != nil {
4✔
4866
                        return nil, err
×
4867
                }
×
4868
                v = casted
4✔
4869
        }
4870

4871
        t, _ := v.RawValue().(time.Time)
18✔
4872

18✔
4873
        year, month, day := t.Date()
18✔
4874

18✔
4875
        switch te.Field {
18✔
4876
        case TimestampFieldTypeYear:
3✔
4877
                return NewInteger(int64(year)), nil
3✔
4878
        case TimestampFieldTypeMonth:
3✔
4879
                return NewInteger(int64(month)), nil
3✔
4880
        case TimestampFieldTypeDay:
3✔
4881
                return NewInteger(int64(day)), nil
3✔
4882
        case TimestampFieldTypeHour:
3✔
4883
                return NewInteger(int64(t.Hour())), nil
3✔
4884
        case TimestampFieldTypeMinute:
3✔
4885
                return NewInteger(int64(t.Minute())), nil
3✔
4886
        case TimestampFieldTypeSecond:
3✔
4887
                return NewInteger(int64(t.Second())), nil
3✔
4888
        }
4889
        return nil, fmt.Errorf("unknown timestamp field type: %s", te.Field)
×
4890
}
4891

4892
func (te *ExtractFromTimestampExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4893
        return &ExtractFromTimestampExp{
×
4894
                Field: te.Field,
×
4895
                Exp:   te.Exp.reduceSelectors(row, implicitTable),
×
4896
        }
×
4897
}
×
4898

4899
func (te *ExtractFromTimestampExp) isConstant() bool {
1✔
4900
        return false
1✔
4901
}
1✔
4902

4903
func (te *ExtractFromTimestampExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
4904
        return nil
×
4905
}
×
4906

4907
func (te *ExtractFromTimestampExp) String() string {
6✔
4908
        return fmt.Sprintf("EXTRACT(%s FROM %s)", te.Field, te.Exp)
6✔
4909
}
6✔
4910

4911
func updateRangeFor(colID uint32, val TypedValue, cmp CmpOperator, rangesByColID map[uint32]*typedValueRange) error {
328✔
4912
        currRange, ranged := rangesByColID[colID]
328✔
4913
        var newRange *typedValueRange
328✔
4914

328✔
4915
        switch cmp {
328✔
4916
        case EQ:
254✔
4917
                {
508✔
4918
                        newRange = &typedValueRange{
254✔
4919
                                lRange: &typedValueSemiRange{
254✔
4920
                                        val:       val,
254✔
4921
                                        inclusive: true,
254✔
4922
                                },
254✔
4923
                                hRange: &typedValueSemiRange{
254✔
4924
                                        val:       val,
254✔
4925
                                        inclusive: true,
254✔
4926
                                },
254✔
4927
                        }
254✔
4928
                }
254✔
4929
        case LT:
13✔
4930
                {
26✔
4931
                        newRange = &typedValueRange{
13✔
4932
                                hRange: &typedValueSemiRange{
13✔
4933
                                        val: val,
13✔
4934
                                },
13✔
4935
                        }
13✔
4936
                }
13✔
4937
        case LE:
12✔
4938
                {
24✔
4939
                        newRange = &typedValueRange{
12✔
4940
                                hRange: &typedValueSemiRange{
12✔
4941
                                        val:       val,
12✔
4942
                                        inclusive: true,
12✔
4943
                                },
12✔
4944
                        }
12✔
4945
                }
12✔
4946
        case GT:
18✔
4947
                {
36✔
4948
                        newRange = &typedValueRange{
18✔
4949
                                lRange: &typedValueSemiRange{
18✔
4950
                                        val: val,
18✔
4951
                                },
18✔
4952
                        }
18✔
4953
                }
18✔
4954
        case GE:
19✔
4955
                {
38✔
4956
                        newRange = &typedValueRange{
19✔
4957
                                lRange: &typedValueSemiRange{
19✔
4958
                                        val:       val,
19✔
4959
                                        inclusive: true,
19✔
4960
                                },
19✔
4961
                        }
19✔
4962
                }
19✔
4963
        case NE:
12✔
4964
                {
24✔
4965
                        return nil
12✔
4966
                }
12✔
4967
        }
4968

4969
        if !ranged {
627✔
4970
                rangesByColID[colID] = newRange
311✔
4971
                return nil
311✔
4972
        }
311✔
4973

4974
        return currRange.refineWith(newRange)
5✔
4975
}
4976

4977
func cmpSatisfiesOp(cmp int, op CmpOperator) bool {
13,911✔
4978
        switch {
13,911✔
4979
        case cmp == 0:
1,160✔
4980
                {
2,320✔
4981
                        return op == EQ || op == LE || op == GE
1,160✔
4982
                }
1,160✔
4983
        case cmp < 0:
6,513✔
4984
                {
13,026✔
4985
                        return op == NE || op == LT || op == LE
6,513✔
4986
                }
6,513✔
4987
        case cmp > 0:
6,238✔
4988
                {
12,476✔
4989
                        return op == NE || op == GT || op == GE
6,238✔
4990
                }
6,238✔
4991
        }
4992
        return false
×
4993
}
4994

4995
type BinBoolExp struct {
4996
        op          LogicOperator
4997
        left, right ValueExp
4998
}
4999

5000
func NewBinBoolExp(op LogicOperator, lrexp, rrexp ValueExp) *BinBoolExp {
18✔
5001
        bexp := &BinBoolExp{
18✔
5002
                op: op,
18✔
5003
        }
18✔
5004

18✔
5005
        bexp.left = lrexp
18✔
5006
        bexp.right = rrexp
18✔
5007

18✔
5008
        return bexp
18✔
5009
}
18✔
5010

5011
func (bexp *BinBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
20✔
5012
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
20✔
5013
        if err != nil {
20✔
5014
                return AnyType, err
×
5015
        }
×
5016

5017
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
20✔
5018
        if err != nil {
22✔
5019
                return AnyType, err
2✔
5020
        }
2✔
5021

5022
        return BooleanType, nil
18✔
5023
}
5024

5025
func (bexp *BinBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
22✔
5026
        if t != BooleanType {
25✔
5027
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
3✔
5028
        }
3✔
5029

5030
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
5031
        if err != nil {
20✔
5032
                return err
1✔
5033
        }
1✔
5034

5035
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
18✔
5036
        if err != nil {
18✔
5037
                return err
×
5038
        }
×
5039

5040
        return nil
18✔
5041
}
5042

5043
func (bexp *BinBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
576✔
5044
        rlexp, err := bexp.left.substitute(params)
576✔
5045
        if err != nil {
576✔
5046
                return nil, err
×
5047
        }
×
5048

5049
        rrexp, err := bexp.right.substitute(params)
576✔
5050
        if err != nil {
576✔
5051
                return nil, err
×
5052
        }
×
5053

5054
        bexp.left = rlexp
576✔
5055
        bexp.right = rrexp
576✔
5056

576✔
5057
        return bexp, nil
576✔
5058
}
5059

5060
func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
535✔
5061
        vl, err := bexp.left.reduce(tx, row, implicitTable)
535✔
5062
        if err != nil {
536✔
5063
                return nil, err
1✔
5064
        }
1✔
5065

5066
        bl, isBool := vl.(*Bool)
534✔
5067
        if !isBool {
534✔
5068
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
5069
        }
×
5070

5071
        // short-circuit evaluation
5072
        if (bl.val && bexp.op == Or) || (!bl.val && bexp.op == And) {
710✔
5073
                return &Bool{val: bl.val}, nil
176✔
5074
        }
176✔
5075

5076
        vr, err := bexp.right.reduce(tx, row, implicitTable)
358✔
5077
        if err != nil {
359✔
5078
                return nil, err
1✔
5079
        }
1✔
5080

5081
        br, isBool := vr.(*Bool)
357✔
5082
        if !isBool {
357✔
5083
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
5084
        }
×
5085

5086
        switch bexp.op {
357✔
5087
        case And:
335✔
5088
                {
670✔
5089
                        return &Bool{val: bl.val && br.val}, nil
335✔
5090
                }
335✔
5091
        case Or:
22✔
5092
                {
44✔
5093
                        return &Bool{val: bl.val || br.val}, nil
22✔
5094
                }
22✔
5095
        }
5096

5097
        return nil, ErrUnexpected
×
5098
}
5099

5100
func (bexp *BinBoolExp) selectors() []Selector {
2✔
5101
        return append(bexp.left.selectors(), bexp.right.selectors()...)
2✔
5102
}
2✔
5103

5104
func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
15✔
5105
        return &BinBoolExp{
15✔
5106
                op:    bexp.op,
15✔
5107
                left:  bexp.left.reduceSelectors(row, implicitTable),
15✔
5108
                right: bexp.right.reduceSelectors(row, implicitTable),
15✔
5109
        }
15✔
5110
}
15✔
5111

5112
func (bexp *BinBoolExp) isConstant() bool {
1✔
5113
        return bexp.left.isConstant() && bexp.right.isConstant()
1✔
5114
}
1✔
5115

5116
func (bexp *BinBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
153✔
5117
        if bexp.op == And {
292✔
5118
                err := bexp.left.selectorRanges(table, asTable, params, rangesByColID)
139✔
5119
                if err != nil {
139✔
5120
                        return err
×
5121
                }
×
5122

5123
                return bexp.right.selectorRanges(table, asTable, params, rangesByColID)
139✔
5124
        }
5125

5126
        lRanges := make(map[uint32]*typedValueRange)
14✔
5127
        rRanges := make(map[uint32]*typedValueRange)
14✔
5128

14✔
5129
        err := bexp.left.selectorRanges(table, asTable, params, lRanges)
14✔
5130
        if err != nil {
14✔
5131
                return err
×
5132
        }
×
5133

5134
        err = bexp.right.selectorRanges(table, asTable, params, rRanges)
14✔
5135
        if err != nil {
14✔
5136
                return err
×
5137
        }
×
5138

5139
        for colID, lr := range lRanges {
21✔
5140
                rr, ok := rRanges[colID]
7✔
5141
                if !ok {
9✔
5142
                        continue
2✔
5143
                }
5144

5145
                err = lr.extendWith(rr)
5✔
5146
                if err != nil {
5✔
5147
                        return err
×
5148
                }
×
5149

5150
                rangesByColID[colID] = lr
5✔
5151
        }
5152

5153
        return nil
14✔
5154
}
5155

5156
func (bexp *BinBoolExp) String() string {
31✔
5157
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), LogicOperatorToString(bexp.op), bexp.right.String())
31✔
5158
}
31✔
5159

5160
// ExistsBoolExp represents an EXISTS clause over a subquery data source.
// Evaluating it is not supported yet: inferType, requiresType and reduce
// all fail with ErrNoSupported.
type ExistsBoolExp struct {
	// q is the subquery the EXISTS clause refers to.
	q DataSource
}
5163

5164
func (bexp *ExistsBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5165
        return AnyType, fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5166
}
1✔
5167

5168
func (bexp *ExistsBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5169
        return fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5170
}
1✔
5171

5172
func (bexp *ExistsBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5173
        return bexp, nil
1✔
5174
}
1✔
5175

5176
func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5177
        return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported)
2✔
5178
}
2✔
5179

5180
func (bexp *ExistsBoolExp) selectors() []Selector {
1✔
5181
        return nil
1✔
5182
}
1✔
5183

5184
func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5185
        return bexp
1✔
5186
}
1✔
5187

5188
func (bexp *ExistsBoolExp) isConstant() bool {
2✔
5189
        return false
2✔
5190
}
2✔
5191

5192
func (bexp *ExistsBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5193
        return nil
1✔
5194
}
1✔
5195

5196
func (bexp *ExistsBoolExp) String() string {
×
5197
        return ""
×
5198
}
×
5199

5200
// InSubQueryExp represents a `val [NOT] IN (subquery)` expression.
// It is not supported yet: inferType, requiresType and reduce all fail with
// ErrNoSupported.
type InSubQueryExp struct {
	val   ValueExp    // operand tested for membership
	notIn bool        // true for the NOT IN form
	q     *SelectStmt // subquery providing the candidate values
}
5205

5206
func (bexp *InSubQueryExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5207
        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5208
}
1✔
5209

5210
func (bexp *InSubQueryExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5211
        return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5212
}
1✔
5213

5214
func (bexp *InSubQueryExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5215
        return bexp, nil
1✔
5216
}
1✔
5217

5218
func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5219
        return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
2✔
5220
}
2✔
5221

5222
func (bexp *InSubQueryExp) selectors() []Selector {
1✔
5223
        return bexp.val.selectors()
1✔
5224
}
1✔
5225

5226
func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5227
        return bexp
1✔
5228
}
1✔
5229

5230
func (bexp *InSubQueryExp) isConstant() bool {
1✔
5231
        return false
1✔
5232
}
1✔
5233

5234
func (bexp *InSubQueryExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5235
        return nil
1✔
5236
}
1✔
5237

5238
func (bexp *InSubQueryExp) String() string {
×
5239
        return ""
×
5240
}
×
5241

5242
// InListExp represents a `val [NOT] IN (v1, v2, ...)` expression over an
// explicit list of value expressions.
//
// TODO: once InSubQueryExp is supported, this struct may become obsolete by creating a ListDataSource struct
type InListExp struct {
	val    ValueExp   // operand tested for membership
	notIn  bool       // when true the membership result is negated (NOT IN)
	values []ValueExp // candidate values compared against val
}
5248

5249
func (bexp *InListExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
6✔
5250
        t, err := bexp.val.inferType(cols, params, implicitTable)
6✔
5251
        if err != nil {
8✔
5252
                return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
2✔
5253
        }
2✔
5254

5255
        for _, v := range bexp.values {
12✔
5256
                err = v.requiresType(t, cols, params, implicitTable)
8✔
5257
                if err != nil {
9✔
5258
                        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
5259
                }
1✔
5260
        }
5261

5262
        return BooleanType, nil
3✔
5263
}
5264

5265
func (bexp *InListExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
5266
        _, err := bexp.inferType(cols, params, implicitTable)
2✔
5267
        if err != nil {
3✔
5268
                return err
1✔
5269
        }
1✔
5270

5271
        if t != BooleanType {
1✔
5272
                return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrInvalidTypes)
×
5273
        }
×
5274

5275
        return nil
1✔
5276
}
5277

5278
func (bexp *InListExp) substitute(params map[string]interface{}) (ValueExp, error) {
115✔
5279
        val, err := bexp.val.substitute(params)
115✔
5280
        if err != nil {
115✔
5281
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5282
        }
×
5283

5284
        values := make([]ValueExp, len(bexp.values))
115✔
5285

115✔
5286
        for i, val := range bexp.values {
245✔
5287
                values[i], err = val.substitute(params)
130✔
5288
                if err != nil {
130✔
5289
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5290
                }
×
5291
        }
5292

5293
        return &InListExp{
115✔
5294
                val:    val,
115✔
5295
                notIn:  bexp.notIn,
115✔
5296
                values: values,
115✔
5297
        }, nil
115✔
5298
}
5299

5300
func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
115✔
5301
        rval, err := bexp.val.reduce(tx, row, implicitTable)
115✔
5302
        if err != nil {
116✔
5303
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5304
        }
1✔
5305

5306
        var found bool
114✔
5307

114✔
5308
        for _, v := range bexp.values {
241✔
5309
                rv, err := v.reduce(tx, row, implicitTable)
127✔
5310
                if err != nil {
128✔
5311
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5312
                }
1✔
5313

5314
                r, err := rval.Compare(rv)
126✔
5315
                if err != nil {
127✔
5316
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5317
                }
1✔
5318

5319
                if r == 0 {
140✔
5320
                        // TODO: short-circuit evaluation may be preferred when upfront static type inference is in place
15✔
5321
                        found = found || true
15✔
5322
                }
15✔
5323
        }
5324

5325
        return &Bool{val: found != bexp.notIn}, nil
112✔
5326
}
5327

5328
func (bexp *InListExp) selectors() []Selector {
1✔
5329
        selectors := make([]Selector, 0, len(bexp.values))
1✔
5330
        for _, v := range bexp.values {
4✔
5331
                selectors = append(selectors, v.selectors()...)
3✔
5332
        }
3✔
5333
        return append(bexp.val.selectors(), selectors...)
1✔
5334
}
5335

5336
func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
5337
        values := make([]ValueExp, len(bexp.values))
10✔
5338

10✔
5339
        for i, val := range bexp.values {
20✔
5340
                values[i] = val.reduceSelectors(row, implicitTable)
10✔
5341
        }
10✔
5342

5343
        return &InListExp{
10✔
5344
                val:    bexp.val.reduceSelectors(row, implicitTable),
10✔
5345
                values: values,
10✔
5346
        }
10✔
5347
}
5348

5349
func (bexp *InListExp) isConstant() bool {
1✔
5350
        return false
1✔
5351
}
1✔
5352

5353
func (bexp *InListExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
21✔
5354
        // TODO: may be determiined by smallest and bigggest value in the list
21✔
5355
        return nil
21✔
5356
}
21✔
5357

5358
func (bexp *InListExp) String() string {
1✔
5359
        values := make([]string, len(bexp.values))
1✔
5360
        for i, exp := range bexp.values {
5✔
5361
                values[i] = exp.String()
4✔
5362
        }
4✔
5363
        return fmt.Sprintf("%s IN (%s)", bexp.val.String(), strings.Join(values, ","))
1✔
5364
}
5365

5366
// FnDataSourceStmt is a data source produced by a built-in function call
// (databases, tables, table, users, columns, indexes or grants — see
// Resolve) rather than by a stored table.
type FnDataSourceStmt struct {
	fnCall *FnCall // the function invocation; Resolve rejects a nil value
	as     string  // optional alias; when empty, Alias derives one from the function name
}
5370

5371
func (stmt *FnDataSourceStmt) readOnly() bool {
1✔
5372
        return true
1✔
5373
}
1✔
5374

5375
func (stmt *FnDataSourceStmt) requiredPrivileges() []SQLPrivilege {
1✔
5376
        return nil
1✔
5377
}
1✔
5378

5379
func (stmt *FnDataSourceStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
5380
        return tx, nil
×
5381
}
×
5382

5383
func (stmt *FnDataSourceStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5384
        return nil
1✔
5385
}
1✔
5386

5387
func (stmt *FnDataSourceStmt) Alias() string {
24✔
5388
        if stmt.as != "" {
26✔
5389
                return stmt.as
2✔
5390
        }
2✔
5391

5392
        switch strings.ToUpper(stmt.fnCall.fn) {
22✔
5393
        case DatabasesFnCall:
3✔
5394
                {
6✔
5395
                        return "databases"
3✔
5396
                }
3✔
5397
        case TablesFnCall:
5✔
5398
                {
10✔
5399
                        return "tables"
5✔
5400
                }
5✔
5401
        case TableFnCall:
×
5402
                {
×
5403
                        return "table"
×
5404
                }
×
5405
        case UsersFnCall:
7✔
5406
                {
14✔
5407
                        return "users"
7✔
5408
                }
7✔
5409
        case ColumnsFnCall:
3✔
5410
                {
6✔
5411
                        return "columns"
3✔
5412
                }
3✔
5413
        case IndexesFnCall:
2✔
5414
                {
4✔
5415
                        return "indexes"
2✔
5416
                }
2✔
5417
        case GrantsFnCall:
2✔
5418
                return "grants"
2✔
5419
        }
5420

5421
        // not reachable
5422
        return ""
×
5423
}
5424

5425
func (stmt *FnDataSourceStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (rowReader RowReader, err error) {
25✔
5426
        if stmt.fnCall == nil {
25✔
5427
                return nil, fmt.Errorf("%w: function is unspecified", ErrIllegalArguments)
×
5428
        }
×
5429

5430
        switch strings.ToUpper(stmt.fnCall.fn) {
25✔
5431
        case DatabasesFnCall:
5✔
5432
                {
10✔
5433
                        return stmt.resolveListDatabases(ctx, tx, params, scanSpecs)
5✔
5434
                }
5✔
5435
        case TablesFnCall:
5✔
5436
                {
10✔
5437
                        return stmt.resolveListTables(ctx, tx, params, scanSpecs)
5✔
5438
                }
5✔
5439
        case TableFnCall:
×
5440
                {
×
5441
                        return stmt.resolveShowTable(ctx, tx, params, scanSpecs)
×
5442
                }
×
5443
        case UsersFnCall:
7✔
5444
                {
14✔
5445
                        return stmt.resolveListUsers(ctx, tx, params, scanSpecs)
7✔
5446
                }
7✔
5447
        case ColumnsFnCall:
3✔
5448
                {
6✔
5449
                        return stmt.resolveListColumns(ctx, tx, params, scanSpecs)
3✔
5450
                }
3✔
5451
        case IndexesFnCall:
3✔
5452
                {
6✔
5453
                        return stmt.resolveListIndexes(ctx, tx, params, scanSpecs)
3✔
5454
                }
3✔
5455
        case GrantsFnCall:
2✔
5456
                {
4✔
5457
                        return stmt.resolveListGrants(ctx, tx, params, scanSpecs)
2✔
5458
                }
2✔
5459
        }
5460

5461
        return nil, fmt.Errorf("%w (%s)", ErrFunctionDoesNotExist, stmt.fnCall.fn)
×
5462
}
5463

5464
func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5465
        if len(stmt.fnCall.params) > 0 {
5✔
5466
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, DatabasesFnCall, len(stmt.fnCall.params))
×
5467
        }
×
5468

5469
        cols := make([]ColDescriptor, 1)
5✔
5470
        cols[0] = ColDescriptor{
5✔
5471
                Column: "name",
5✔
5472
                Type:   VarcharType,
5✔
5473
        }
5✔
5474

5✔
5475
        var dbs []string
5✔
5476

5✔
5477
        if tx.engine.multidbHandler == nil {
6✔
5478
                return nil, ErrUnspecifiedMultiDBHandler
1✔
5479
        } else {
5✔
5480
                dbs, err = tx.engine.multidbHandler.ListDatabases(ctx)
4✔
5481
                if err != nil {
4✔
5482
                        return nil, err
×
5483
                }
×
5484
        }
5485

5486
        values := make([][]ValueExp, len(dbs))
4✔
5487

4✔
5488
        for i, db := range dbs {
12✔
5489
                values[i] = []ValueExp{&Varchar{val: db}}
8✔
5490
        }
8✔
5491

5492
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
4✔
5493
}
5494

5495
func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5496
        if len(stmt.fnCall.params) > 0 {
5✔
5497
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, TablesFnCall, len(stmt.fnCall.params))
×
5498
        }
×
5499

5500
        cols := make([]ColDescriptor, 1)
5✔
5501
        cols[0] = ColDescriptor{
5✔
5502
                Column: "name",
5✔
5503
                Type:   VarcharType,
5✔
5504
        }
5✔
5505

5✔
5506
        tables := tx.catalog.GetTables()
5✔
5507

5✔
5508
        values := make([][]ValueExp, len(tables))
5✔
5509

5✔
5510
        for i, t := range tables {
14✔
5511
                values[i] = []ValueExp{&Varchar{val: t.name}}
9✔
5512
        }
9✔
5513

5514
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5515
}
5516

5517
func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
×
5518
        cols := []ColDescriptor{
×
5519
                {
×
5520
                        Column: "column_name",
×
5521
                        Type:   VarcharType,
×
5522
                },
×
5523
                {
×
5524
                        Column: "type_name",
×
5525
                        Type:   VarcharType,
×
5526
                },
×
5527
                {
×
5528
                        Column: "is_nullable",
×
5529
                        Type:   BooleanType,
×
5530
                },
×
5531
                {
×
5532
                        Column: "is_indexed",
×
5533
                        Type:   VarcharType,
×
5534
                },
×
5535
                {
×
5536
                        Column: "is_auto_increment",
×
5537
                        Type:   BooleanType,
×
5538
                },
×
5539
                {
×
5540
                        Column: "is_unique",
×
5541
                        Type:   BooleanType,
×
5542
                },
×
5543
        }
×
5544

×
5545
        tableName, _ := stmt.fnCall.params[0].reduce(tx, nil, "")
×
5546
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
×
5547
        if err != nil {
×
5548
                return nil, err
×
5549
        }
×
5550

5551
        values := make([][]ValueExp, len(table.cols))
×
5552

×
5553
        for i, c := range table.cols {
×
5554
                index := "NO"
×
5555

×
5556
                indexed, err := table.IsIndexed(c.Name())
×
5557
                if err != nil {
×
5558
                        return nil, err
×
5559
                }
×
5560
                if indexed {
×
5561
                        index = "YES"
×
5562
                }
×
5563

5564
                if table.PrimaryIndex().IncludesCol(c.ID()) {
×
5565
                        index = "PRIMARY KEY"
×
5566
                }
×
5567

5568
                var unique bool
×
5569
                for _, index := range table.GetIndexesByColID(c.ID()) {
×
5570
                        if index.IsUnique() && len(index.Cols()) == 1 {
×
5571
                                unique = true
×
5572
                                break
×
5573
                        }
5574
                }
5575

5576
                var maxLen string
×
5577

×
5578
                if c.MaxLen() > 0 && (c.Type() == VarcharType || c.Type() == BLOBType) {
×
5579
                        maxLen = fmt.Sprintf("(%d)", c.MaxLen())
×
5580
                }
×
5581

5582
                values[i] = []ValueExp{
×
5583
                        &Varchar{val: c.colName},
×
5584
                        &Varchar{val: c.Type() + maxLen},
×
5585
                        &Bool{val: c.IsNullable()},
×
5586
                        &Varchar{val: index},
×
5587
                        &Bool{val: c.IsAutoIncremental()},
×
5588
                        &Bool{val: unique},
×
5589
                }
×
5590
        }
5591

5592
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
×
5593
}
5594

5595
func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
7✔
5596
        if len(stmt.fnCall.params) > 0 {
7✔
5597
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, UsersFnCall, len(stmt.fnCall.params))
×
5598
        }
×
5599

5600
        cols := []ColDescriptor{
7✔
5601
                {
7✔
5602
                        Column: "name",
7✔
5603
                        Type:   VarcharType,
7✔
5604
                },
7✔
5605
                {
7✔
5606
                        Column: "permission",
7✔
5607
                        Type:   VarcharType,
7✔
5608
                },
7✔
5609
        }
7✔
5610

7✔
5611
        users, err := tx.ListUsers(ctx)
7✔
5612
        if err != nil {
7✔
5613
                return nil, err
×
5614
        }
×
5615

5616
        values := make([][]ValueExp, len(users))
7✔
5617
        for i, user := range users {
23✔
5618
                perm := user.Permission()
16✔
5619

16✔
5620
                values[i] = []ValueExp{
16✔
5621
                        &Varchar{val: user.Username()},
16✔
5622
                        &Varchar{val: perm},
16✔
5623
                }
16✔
5624
        }
16✔
5625
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
7✔
5626
}
5627

5628
func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5629
        if len(stmt.fnCall.params) != 1 {
3✔
5630
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, ColumnsFnCall)
×
5631
        }
×
5632

5633
        cols := []ColDescriptor{
3✔
5634
                {
3✔
5635
                        Column: "table",
3✔
5636
                        Type:   VarcharType,
3✔
5637
                },
3✔
5638
                {
3✔
5639
                        Column: "name",
3✔
5640
                        Type:   VarcharType,
3✔
5641
                },
3✔
5642
                {
3✔
5643
                        Column: "type",
3✔
5644
                        Type:   VarcharType,
3✔
5645
                },
3✔
5646
                {
3✔
5647
                        Column: "max_length",
3✔
5648
                        Type:   IntegerType,
3✔
5649
                },
3✔
5650
                {
3✔
5651
                        Column: "nullable",
3✔
5652
                        Type:   BooleanType,
3✔
5653
                },
3✔
5654
                {
3✔
5655
                        Column: "auto_increment",
3✔
5656
                        Type:   BooleanType,
3✔
5657
                },
3✔
5658
                {
3✔
5659
                        Column: "indexed",
3✔
5660
                        Type:   BooleanType,
3✔
5661
                },
3✔
5662
                {
3✔
5663
                        Column: "primary",
3✔
5664
                        Type:   BooleanType,
3✔
5665
                },
3✔
5666
                {
3✔
5667
                        Column: "unique",
3✔
5668
                        Type:   BooleanType,
3✔
5669
                },
3✔
5670
        }
3✔
5671

3✔
5672
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5673
        if err != nil {
3✔
5674
                return nil, err
×
5675
        }
×
5676

5677
        tableName, err := val.reduce(tx, nil, "")
3✔
5678
        if err != nil {
3✔
5679
                return nil, err
×
5680
        }
×
5681

5682
        if tableName.Type() != VarcharType {
3✔
5683
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5684
        }
×
5685

5686
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5687
        if err != nil {
3✔
5688
                return nil, err
×
5689
        }
×
5690

5691
        values := make([][]ValueExp, len(table.cols))
3✔
5692

3✔
5693
        for i, c := range table.cols {
11✔
5694
                indexed, err := table.IsIndexed(c.Name())
8✔
5695
                if err != nil {
8✔
5696
                        return nil, err
×
5697
                }
×
5698

5699
                var unique bool
8✔
5700
                for _, index := range table.indexesByColID[c.id] {
16✔
5701
                        if index.IsUnique() && len(index.Cols()) == 1 {
11✔
5702
                                unique = true
3✔
5703
                                break
3✔
5704
                        }
5705
                }
5706

5707
                values[i] = []ValueExp{
8✔
5708
                        &Varchar{val: table.name},
8✔
5709
                        &Varchar{val: c.colName},
8✔
5710
                        &Varchar{val: c.colType},
8✔
5711
                        &Integer{val: int64(c.MaxLen())},
8✔
5712
                        &Bool{val: c.IsNullable()},
8✔
5713
                        &Bool{val: c.autoIncrement},
8✔
5714
                        &Bool{val: indexed},
8✔
5715
                        &Bool{val: table.PrimaryIndex().IncludesCol(c.ID())},
8✔
5716
                        &Bool{val: unique},
8✔
5717
                }
8✔
5718
        }
5719

5720
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5721
}
5722

5723
func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5724
        if len(stmt.fnCall.params) != 1 {
3✔
5725
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, IndexesFnCall)
×
5726
        }
×
5727

5728
        cols := []ColDescriptor{
3✔
5729
                {
3✔
5730
                        Column: "table",
3✔
5731
                        Type:   VarcharType,
3✔
5732
                },
3✔
5733
                {
3✔
5734
                        Column: "name",
3✔
5735
                        Type:   VarcharType,
3✔
5736
                },
3✔
5737
                {
3✔
5738
                        Column: "unique",
3✔
5739
                        Type:   BooleanType,
3✔
5740
                },
3✔
5741
                {
3✔
5742
                        Column: "primary",
3✔
5743
                        Type:   BooleanType,
3✔
5744
                },
3✔
5745
        }
3✔
5746

3✔
5747
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5748
        if err != nil {
3✔
5749
                return nil, err
×
5750
        }
×
5751

5752
        tableName, err := val.reduce(tx, nil, "")
3✔
5753
        if err != nil {
3✔
5754
                return nil, err
×
5755
        }
×
5756

5757
        if tableName.Type() != VarcharType {
3✔
5758
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5759
        }
×
5760

5761
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5762
        if err != nil {
3✔
5763
                return nil, err
×
5764
        }
×
5765

5766
        values := make([][]ValueExp, len(table.indexes))
3✔
5767

3✔
5768
        for i, index := range table.indexes {
10✔
5769
                values[i] = []ValueExp{
7✔
5770
                        &Varchar{val: table.name},
7✔
5771
                        &Varchar{val: index.Name()},
7✔
5772
                        &Bool{val: index.unique},
7✔
5773
                        &Bool{val: index.IsPrimary()},
7✔
5774
                }
7✔
5775
        }
7✔
5776

5777
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5778
}
5779

5780
func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
2✔
5781
        if len(stmt.fnCall.params) > 1 {
2✔
5782
                return nil, fmt.Errorf("%w: function '%s' expect at most one parameter of type %s", ErrIllegalArguments, GrantsFnCall, VarcharType)
×
5783
        }
×
5784

5785
        var username string
2✔
5786
        if len(stmt.fnCall.params) == 1 {
3✔
5787
                val, err := stmt.fnCall.params[0].substitute(params)
1✔
5788
                if err != nil {
1✔
5789
                        return nil, err
×
5790
                }
×
5791

5792
                userVal, err := val.reduce(tx, nil, "")
1✔
5793
                if err != nil {
1✔
5794
                        return nil, err
×
5795
                }
×
5796

5797
                if userVal.Type() != VarcharType {
1✔
5798
                        return nil, fmt.Errorf("%w: expected '%s' for username but type '%s' given instead", ErrIllegalArguments, VarcharType, userVal.Type())
×
5799
                }
×
5800
                username, _ = userVal.RawValue().(string)
1✔
5801
        }
5802

5803
        cols := []ColDescriptor{
2✔
5804
                {
2✔
5805
                        Column: "user",
2✔
5806
                        Type:   VarcharType,
2✔
5807
                },
2✔
5808
                {
2✔
5809
                        Column: "privilege",
2✔
5810
                        Type:   VarcharType,
2✔
5811
                },
2✔
5812
        }
2✔
5813

2✔
5814
        var err error
2✔
5815
        var users []User
2✔
5816

2✔
5817
        if tx.engine.multidbHandler == nil {
2✔
5818
                return nil, ErrUnspecifiedMultiDBHandler
×
5819
        } else {
2✔
5820
                users, err = tx.engine.multidbHandler.ListUsers(ctx)
2✔
5821
                if err != nil {
2✔
5822
                        return nil, err
×
5823
                }
×
5824
        }
5825

5826
        values := make([][]ValueExp, 0, len(users))
2✔
5827

2✔
5828
        for _, user := range users {
4✔
5829
                if username == "" || user.Username() == username {
4✔
5830
                        for _, p := range user.SQLPrivileges() {
6✔
5831
                                values = append(values, []ValueExp{
4✔
5832
                                        &Varchar{val: user.Username()},
4✔
5833
                                        &Varchar{val: string(p)},
4✔
5834
                                })
4✔
5835
                        }
4✔
5836
                }
5837
        }
5838

5839
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
2✔
5840
}
5841

5842
// DropTableStmt represents a statement to delete a table.
// Executing it requires SQLPrivilegeDrop.
type DropTableStmt struct {
	table string // name of the table to drop
}
5846

5847
func NewDropTableStmt(table string) *DropTableStmt {
6✔
5848
        return &DropTableStmt{table: table}
6✔
5849
}
6✔
5850

5851
func (stmt *DropTableStmt) readOnly() bool {
1✔
5852
        return false
1✔
5853
}
1✔
5854

5855
func (stmt *DropTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
5856
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5857
}
1✔
5858

5859
func (stmt *DropTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5860
        return nil
1✔
5861
}
1✔
5862

5863
/*
5864
Exec executes the delete table statement.
5865
If the table does not exist, it returns ErrTableDoesNotExist.
5866
If the table exists, it deletes all the indexes and the table itself.
5867
Note that this is a soft delete of the index and table key,
5868
the data is not deleted, but the metadata is updated.
5869
*/
5870
func (stmt *DropTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
7✔
5871
        if !tx.catalog.ExistTable(stmt.table) {
8✔
5872
                return nil, ErrTableDoesNotExist
1✔
5873
        }
1✔
5874

5875
        table, err := tx.catalog.GetTableByName(stmt.table)
6✔
5876
        if err != nil {
6✔
5877
                return nil, err
×
5878
        }
×
5879

5880
        // delete table
5881
        mappedKey := MapKey(
6✔
5882
                tx.sqlPrefix(),
6✔
5883
                catalogTablePrefix,
6✔
5884
                EncodeID(DatabaseID),
6✔
5885
                EncodeID(table.id),
6✔
5886
        )
6✔
5887
        err = tx.delete(ctx, mappedKey)
6✔
5888
        if err != nil {
6✔
5889
                return nil, err
×
5890
        }
×
5891

5892
        // delete columns
5893
        cols := table.ColumnsByID()
6✔
5894
        for _, col := range cols {
26✔
5895
                mappedKey := MapKey(
20✔
5896
                        tx.sqlPrefix(),
20✔
5897
                        catalogColumnPrefix,
20✔
5898
                        EncodeID(DatabaseID),
20✔
5899
                        EncodeID(col.table.id),
20✔
5900
                        EncodeID(col.id),
20✔
5901
                        []byte(col.colType),
20✔
5902
                )
20✔
5903
                err = tx.delete(ctx, mappedKey)
20✔
5904
                if err != nil {
20✔
5905
                        return nil, err
×
5906
                }
×
5907
        }
5908

5909
        // delete checks
5910
        for name := range table.checkConstraints {
6✔
5911
                key := MapKey(
×
5912
                        tx.sqlPrefix(),
×
5913
                        catalogCheckPrefix,
×
5914
                        EncodeID(DatabaseID),
×
5915
                        EncodeID(table.id),
×
5916
                        []byte(name),
×
5917
                )
×
5918

×
5919
                if err := tx.delete(ctx, key); err != nil {
×
5920
                        return nil, err
×
5921
                }
×
5922
        }
5923

5924
        // delete indexes
5925
        for _, index := range table.indexes {
13✔
5926
                mappedKey := MapKey(
7✔
5927
                        tx.sqlPrefix(),
7✔
5928
                        catalogIndexPrefix,
7✔
5929
                        EncodeID(DatabaseID),
7✔
5930
                        EncodeID(table.id),
7✔
5931
                        EncodeID(index.id),
7✔
5932
                )
7✔
5933
                err = tx.delete(ctx, mappedKey)
7✔
5934
                if err != nil {
7✔
5935
                        return nil, err
×
5936
                }
×
5937

5938
                indexKey := MapKey(
7✔
5939
                        tx.sqlPrefix(),
7✔
5940
                        MappedPrefix,
7✔
5941
                        EncodeID(table.id),
7✔
5942
                        EncodeID(index.id),
7✔
5943
                )
7✔
5944
                err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
14✔
5945
                        return sqlTx.engine.store.DeleteIndex(indexKey)
7✔
5946
                })
7✔
5947
                if err != nil {
7✔
5948
                        return nil, err
×
5949
                }
×
5950
        }
5951

5952
        err = tx.catalog.deleteTable(table)
6✔
5953
        if err != nil {
6✔
5954
                return nil, err
×
5955
        }
×
5956

5957
        tx.mutatedCatalog = true
6✔
5958

6✔
5959
        return tx, nil
6✔
5960
}
5961

5962
// DropIndexStmt represents a statement to delete a table.
5963
type DropIndexStmt struct {
5964
        table string
5965
        cols  []string
5966
}
5967

5968
func NewDropIndexStmt(table string, cols []string) *DropIndexStmt {
4✔
5969
        return &DropIndexStmt{table: table, cols: cols}
4✔
5970
}
4✔
5971

5972
func (stmt *DropIndexStmt) readOnly() bool {
1✔
5973
        return false
1✔
5974
}
1✔
5975

5976
func (stmt *DropIndexStmt) requiredPrivileges() []SQLPrivilege {
1✔
5977
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5978
}
1✔
5979

5980
func (stmt *DropIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5981
        return nil
1✔
5982
}
1✔
5983

5984
/*
5985
Exec executes the delete index statement.
5986
If the index exists, it deletes it. Note that this is a soft delete of the index
5987
the data is not deleted, but the metadata is updated.
5988
*/
5989
func (stmt *DropIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
5990
        if !tx.catalog.ExistTable(stmt.table) {
7✔
5991
                return nil, ErrTableDoesNotExist
1✔
5992
        }
1✔
5993

5994
        table, err := tx.catalog.GetTableByName(stmt.table)
5✔
5995
        if err != nil {
5✔
5996
                return nil, err
×
5997
        }
×
5998

5999
        cols := make([]*Column, len(stmt.cols))
5✔
6000

5✔
6001
        for i, colName := range stmt.cols {
10✔
6002
                col, err := table.GetColumnByName(colName)
5✔
6003
                if err != nil {
5✔
6004
                        return nil, err
×
6005
                }
×
6006

6007
                cols[i] = col
5✔
6008
        }
6009

6010
        index, err := table.GetIndexByName(indexName(table.name, cols))
5✔
6011
        if err != nil {
5✔
6012
                return nil, err
×
6013
        }
×
6014

6015
        // delete index
6016
        mappedKey := MapKey(
5✔
6017
                tx.sqlPrefix(),
5✔
6018
                catalogIndexPrefix,
5✔
6019
                EncodeID(DatabaseID),
5✔
6020
                EncodeID(table.id),
5✔
6021
                EncodeID(index.id),
5✔
6022
        )
5✔
6023
        err = tx.delete(ctx, mappedKey)
5✔
6024
        if err != nil {
5✔
6025
                return nil, err
×
6026
        }
×
6027

6028
        indexKey := MapKey(
5✔
6029
                tx.sqlPrefix(),
5✔
6030
                MappedPrefix,
5✔
6031
                EncodeID(table.id),
5✔
6032
                EncodeID(index.id),
5✔
6033
        )
5✔
6034

5✔
6035
        err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
9✔
6036
                return sqlTx.engine.store.DeleteIndex(indexKey)
4✔
6037
        })
4✔
6038
        if err != nil {
5✔
6039
                return nil, err
×
6040
        }
×
6041

6042
        err = table.deleteIndex(index)
5✔
6043
        if err != nil {
6✔
6044
                return nil, err
1✔
6045
        }
1✔
6046

6047
        tx.mutatedCatalog = true
4✔
6048

4✔
6049
        return tx, nil
4✔
6050
}
6051

6052
// SQLPrivilege is the textual name of a privilege over SQL operations.
type SQLPrivilege string

// Supported SQL privileges.
const (
	SQLPrivilegeSelect SQLPrivilege = "SELECT"
	SQLPrivilegeCreate SQLPrivilege = "CREATE"
	SQLPrivilegeInsert SQLPrivilege = "INSERT"
	SQLPrivilegeUpdate SQLPrivilege = "UPDATE"
	SQLPrivilegeDelete SQLPrivilege = "DELETE"
	SQLPrivilegeDrop   SQLPrivilege = "DROP"
	SQLPrivilegeAlter  SQLPrivilege = "ALTER"
)

// allPrivileges enumerates every supported SQLPrivilege, in declaration order.
var allPrivileges = []SQLPrivilege{
	SQLPrivilegeSelect, SQLPrivilegeCreate, SQLPrivilegeInsert,
	SQLPrivilegeUpdate, SQLPrivilegeDelete, SQLPrivilegeDrop,
	SQLPrivilegeAlter,
}
6073

6074
func DefaultSQLPrivilegesForPermission(p Permission) []SQLPrivilege {
295✔
6075
        switch p {
295✔
6076
        case PermissionSysAdmin, PermissionAdmin, PermissionReadWrite:
284✔
6077
                return allPrivileges
284✔
6078
        case PermissionReadOnly:
11✔
6079
                return []SQLPrivilege{SQLPrivilegeSelect}
11✔
6080
        }
6081
        return nil
×
6082
}
6083

6084
type AlterPrivilegesStmt struct {
6085
        database   string
6086
        user       string
6087
        privileges []SQLPrivilege
6088
        isGrant    bool
6089
}
6090

6091
func (stmt *AlterPrivilegesStmt) readOnly() bool {
2✔
6092
        return false
2✔
6093
}
2✔
6094

6095
func (stmt *AlterPrivilegesStmt) requiredPrivileges() []SQLPrivilege {
2✔
6096
        return nil
2✔
6097
}
2✔
6098

6099
func (stmt *AlterPrivilegesStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2✔
6100
        if tx.IsExplicitCloseRequired() {
3✔
6101
                return nil, fmt.Errorf("%w: user privileges modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
6102
        }
1✔
6103

6104
        if tx.engine.multidbHandler == nil {
1✔
6105
                return nil, ErrUnspecifiedMultiDBHandler
×
6106
        }
×
6107

6108
        var err error
1✔
6109
        if stmt.isGrant {
1✔
6110
                err = tx.engine.multidbHandler.GrantSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
×
6111
        } else {
1✔
6112
                err = tx.engine.multidbHandler.RevokeSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
1✔
6113
        }
1✔
6114
        return nil, err
1✔
6115
}
6116

6117
func (stmt *AlterPrivilegesStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
6118
        return nil
1✔
6119
}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc