• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 18612223019

18 Oct 2025 06:47AM UTC coverage: 89.265%. Remained the same
18612223019

push

gh-ci

ostafen
chore(embedded/sql): Implement EXTRACT FROM TIMESTAMP expressions

Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>

421 of 455 new or added lines in 3 files covered. (92.53%)

1 existing line in 1 file now uncovered.

37943 of 42506 relevant lines covered (89.27%)

150551.33 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.2
/embedded/sql/stmt.go
1
/*
2
Copyright 2025 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "context"
22
        "encoding/binary"
23
        "encoding/hex"
24
        "errors"
25
        "fmt"
26
        "math"
27
        "regexp"
28
        "strconv"
29
        "strings"
30
        "time"
31

32
        "github.com/codenotary/immudb/embedded/store"
33
        "github.com/google/uuid"
34
)
35

36
// Key prefixes used to persist catalog metadata and table data.
//
// Catalog keys are built as MapKey(tx.sqlPrefix(), <prefix>, ...), so each
// layout below is additionally preceded by the transaction's SQL prefix.
// In the layouts, {1} stands for EncodeID(DatabaseID).
const (
	catalogPrefix          = "CTL."
	catalogTablePrefix     = "CTL.TABLE."     // (key=CTL.TABLE.{1}{tableID}, value={tableNAME})
	catalogColumnPrefix    = "CTL.COLUMN."    // (key=CTL.COLUMN.{1}{tableID}{colID}{colTYPE}, value={(auto_incremental | not_null){maxLen}{colNAME}}); the not_null bit is the constant named nullableFlag
	catalogIndexPrefix     = "CTL.INDEX."     // (key=CTL.INDEX.{1}{tableID}{indexID}, value={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)})
	catalogCheckPrefix     = "CTL.CHECK."     // (key=CTL.CHECK.{1}{tableID}{checkID}, value={nameLen-1}{name}{expText})
	catalogPrivilegePrefix = "CTL.PRIVILEGE." // key/value layout is not shown in this file; TODO(review): document the CTL.PRIVILEGE. format

	RowPrefix    = "R." // (key=R.{1}{tableID}{0}({null}({pkVal}{padding}{pkValLen})?)+, value={count (colID valLen val)+})
	MappedPrefix = "M." // (key=M.{tableID}{indexID}({null}({val}{padding}{valLen})?)*({pkVal}{padding}{pkValLen})+, value={count (colID valLen val)+})
)
47

48
// DatabaseID is the database identifier embedded in catalog keys. It no
// longer carries meaning but is kept so that key layouts stay backwards
// compatible.
const DatabaseID uint32 = 1

// PKIndexID is the index ID associated with a table's primary-key index
// (inferred from the name; confirm against the catalog code).
const PKIndexID uint32 = 0
52

53
// Flag bits stored in the first byte of a persisted column value.
const (
	// nullableFlag is set by persistColumn when the column is declared
	// NOT NULL, despite what its name suggests.
	nullableFlag byte = 0x01
	// autoIncrementFlag is set for AUTO_INCREMENT columns.
	autoIncrementFlag byte = 0x02
)
57

58
// Names of engine-managed pseudo-columns.
const (
	revCol        = "_rev"
	txMetadataCol = "_tx_metadata"
)

// reservedColumns is the set of column names reserved by the engine.
var reservedColumns = map[string]struct{}{
	revCol:        {},
	txMetadataCol: {},
}

// isReservedCol reports whether col exactly matches (case-sensitively) one of
// the names in reservedColumns.
func isReservedCol(col string) bool {
	if _, found := reservedColumns[col]; found {
		return true
	}
	return false
}
72

73
// SQLValueType is the name of an SQL value type.
type SQLValueType = string

// SQL value types known to the engine.
const (
	IntegerType   SQLValueType = "INTEGER"
	BooleanType   SQLValueType = "BOOLEAN"
	VarcharType   SQLValueType = "VARCHAR"
	UUIDType      SQLValueType = "UUID"
	BLOBType      SQLValueType = "BLOB"
	Float64Type   SQLValueType = "FLOAT"
	TimestampType SQLValueType = "TIMESTAMP"
	AnyType       SQLValueType = "ANY"
	JSONType      SQLValueType = "JSON"
)

// IsNumericType reports whether t is INTEGER or FLOAT.
func IsNumericType(t SQLValueType) bool {
	switch t {
	case IntegerType, Float64Type:
		return true
	default:
		return false
	}
}
90

91
// Permission is the name of a user permission level.
type Permission = string

// Permission levels.
const (
	PermissionReadOnly  Permission = "READ"
	PermissionReadWrite Permission = "READWRITE"
	PermissionAdmin     Permission = "ADMIN"
	PermissionSysAdmin  Permission = "SYSADMIN"
)

// PermissionFromCode maps a numeric permission code to its Permission:
// 1 is READ, 2 is READWRITE and 254 is ADMIN. Every other code yields SYSADMIN.
//
// NOTE(review): the fallback means unknown codes (including 0) map to the
// most privileged level — confirm this is intended.
func PermissionFromCode(code uint32) Permission {
	switch code {
	case 1:
		return PermissionReadOnly
	case 2:
		return PermissionReadWrite
	case 254:
		return PermissionAdmin
	default:
		return PermissionSysAdmin
	}
}
117

118
// AggregateFn is the name of an SQL aggregate function.
type AggregateFn = string

// Aggregate function names, spelled as they appear in SQL.
const (
	COUNT AggregateFn = "COUNT"
	SUM   AggregateFn = "SUM"
	MAX   AggregateFn = "MAX"
	MIN   AggregateFn = "MIN"
	AVG   AggregateFn = "AVG"
)
127

128
// CmpOperator identifies a binary comparison operator.
type CmpOperator = int

// Comparison operators, numbered sequentially from zero.
const (
	EQ CmpOperator = iota // =
	NE                    // !=
	LT                    // <
	LE                    // <=
	GT                    // >
	GE                    // >=
)

// CmpOperatorToString returns the SQL symbol for op, or "" when op is not a
// known comparison operator.
func CmpOperatorToString(op CmpOperator) string {
	var symbol string
	switch op {
	case EQ:
		symbol = "="
	case NE:
		symbol = "!="
	case LT:
		symbol = "<"
	case LE:
		symbol = "<="
	case GT:
		symbol = ">"
	case GE:
		symbol = ">="
	}
	return symbol
}
156

157
// LogicOperator identifies a boolean connective.
type LogicOperator = int

// Boolean connectives.
const (
	And LogicOperator = iota
	Or
)

// LogicOperatorToString returns "AND" for And and "OR" for any other value.
func LogicOperatorToString(op LogicOperator) string {
	if op != And {
		return "OR"
	}
	return "AND"
}
170

171
// NumOperator identifies a binary arithmetic operator.
type NumOperator = int

// Arithmetic operators, numbered sequentially from zero.
const (
	ADDOP NumOperator = iota
	SUBSOP
	DIVOP
	MULTOP
	MODOP
)

// NumOperatorString returns the SQL symbol for op, or "" when op is not a
// known arithmetic operator.
func NumOperatorString(op NumOperator) string {
	symbols := [...]string{
		ADDOP:  "+",
		SUBSOP: "-",
		DIVOP:  "/",
		MULTOP: "*",
		MODOP:  "%",
	}
	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
196

197
// JoinType identifies the kind of a JOIN clause.
type JoinType = int

// Join kinds, numbered sequentially from zero.
const (
	InnerJoin JoinType = iota // INNER JOIN
	LeftJoin                  // LEFT JOIN
	RightJoin                 // RIGHT JOIN
)
204

205
type SQLStmt interface {
206
        readOnly() bool
207
        requiredPrivileges() []SQLPrivilege
208
        execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error)
209
        inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error
210
}
211

212
type BeginTransactionStmt struct {
213
}
214

215
func (stmt *BeginTransactionStmt) readOnly() bool {
9✔
216
        return true
9✔
217
}
9✔
218

219
func (stmt *BeginTransactionStmt) requiredPrivileges() []SQLPrivilege {
9✔
220
        return nil
9✔
221
}
9✔
222

223
func (stmt *BeginTransactionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
224
        return nil
3✔
225
}
3✔
226

227
func (stmt *BeginTransactionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
61✔
228
        if tx.IsExplicitCloseRequired() {
62✔
229
                return nil, ErrNestedTxNotSupported
1✔
230
        }
1✔
231

232
        err := tx.RequireExplicitClose()
60✔
233
        if err == nil {
119✔
234
                // current tx can be reused as no changes were already made
59✔
235
                return tx, nil
59✔
236
        }
59✔
237

238
        // commit current transaction and start a fresh one
239

240
        err = tx.Commit(ctx)
1✔
241
        if err != nil {
1✔
242
                return nil, err
×
243
        }
×
244

245
        return tx.engine.NewTx(ctx, tx.opts.WithExplicitClose(true))
1✔
246
}
247

248
type CommitStmt struct {
249
}
250

251
func (stmt *CommitStmt) readOnly() bool {
113✔
252
        return true
113✔
253
}
113✔
254

255
func (stmt *CommitStmt) requiredPrivileges() []SQLPrivilege {
113✔
256
        return nil
113✔
257
}
113✔
258

259
func (stmt *CommitStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
260
        return nil
1✔
261
}
1✔
262

263
func (stmt *CommitStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
159✔
264
        if !tx.IsExplicitCloseRequired() {
160✔
265
                return nil, ErrNoOngoingTx
1✔
266
        }
1✔
267

268
        return nil, tx.Commit(ctx)
158✔
269
}
270

271
type RollbackStmt struct {
272
}
273

274
func (stmt *RollbackStmt) readOnly() bool {
1✔
275
        return true
1✔
276
}
1✔
277

278
func (stmt *RollbackStmt) requiredPrivileges() []SQLPrivilege {
1✔
279
        return nil
1✔
280
}
1✔
281

282
func (stmt *RollbackStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
283
        return nil
1✔
284
}
1✔
285

286
func (stmt *RollbackStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
287
        if !tx.IsExplicitCloseRequired() {
5✔
288
                return nil, ErrNoOngoingTx
1✔
289
        }
1✔
290

291
        return nil, tx.Cancel()
3✔
292
}
293

294
type CreateDatabaseStmt struct {
295
        DB          string
296
        ifNotExists bool
297
}
298

299
func (stmt *CreateDatabaseStmt) readOnly() bool {
14✔
300
        return false
14✔
301
}
14✔
302

303
func (stmt *CreateDatabaseStmt) requiredPrivileges() []SQLPrivilege {
14✔
304
        return []SQLPrivilege{SQLPrivilegeCreate}
14✔
305
}
14✔
306

307
func (stmt *CreateDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
308
        return nil
4✔
309
}
4✔
310

311
func (stmt *CreateDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
312
        if tx.IsExplicitCloseRequired() {
16✔
313
                return nil, fmt.Errorf("%w: database creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
314
        }
1✔
315

316
        if tx.engine.multidbHandler == nil {
16✔
317
                return nil, ErrUnspecifiedMultiDBHandler
2✔
318
        }
2✔
319

320
        return nil, tx.engine.multidbHandler.CreateDatabase(ctx, stmt.DB, stmt.ifNotExists)
12✔
321
}
322

323
type UseDatabaseStmt struct {
324
        DB string
325
}
326

327
func (stmt *UseDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
328
        return nil
1✔
329
}
1✔
330

331
func (stmt *UseDatabaseStmt) readOnly() bool {
9✔
332
        return true
9✔
333
}
9✔
334

335
func (stmt *UseDatabaseStmt) requiredPrivileges() []SQLPrivilege {
9✔
336
        return nil
9✔
337
}
9✔
338

339
func (stmt *UseDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
340
        if tx.IsExplicitCloseRequired() {
10✔
341
                return nil, fmt.Errorf("%w: database selection can NOT be executed within a transaction block", ErrNonTransactionalStmt)
1✔
342
        }
1✔
343

344
        if tx.engine.multidbHandler == nil {
9✔
345
                return nil, ErrUnspecifiedMultiDBHandler
1✔
346
        }
1✔
347

348
        return tx, tx.engine.multidbHandler.UseDatabase(ctx, stmt.DB)
7✔
349
}
350

351
type UseSnapshotStmt struct {
352
        period period
353
}
354

355
func (stmt *UseSnapshotStmt) readOnly() bool {
1✔
356
        return true
1✔
357
}
1✔
358

359
func (stmt *UseSnapshotStmt) requiredPrivileges() []SQLPrivilege {
1✔
360
        return nil
1✔
361
}
1✔
362

363
func (stmt *UseSnapshotStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
364
        return nil
1✔
365
}
1✔
366

367
func (stmt *UseSnapshotStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
368
        return nil, ErrNoSupported
1✔
369
}
1✔
370

371
type CreateUserStmt struct {
372
        username   string
373
        password   string
374
        permission Permission
375
}
376

377
func (stmt *CreateUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
378
        return nil
1✔
379
}
1✔
380

381
func (stmt *CreateUserStmt) readOnly() bool {
5✔
382
        return false
5✔
383
}
5✔
384

385
func (stmt *CreateUserStmt) requiredPrivileges() []SQLPrivilege {
5✔
386
        return []SQLPrivilege{SQLPrivilegeCreate}
5✔
387
}
5✔
388

389
func (stmt *CreateUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
390
        if tx.IsExplicitCloseRequired() {
7✔
391
                return nil, fmt.Errorf("%w: user creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
392
        }
1✔
393

394
        if tx.engine.multidbHandler == nil {
6✔
395
                return nil, ErrUnspecifiedMultiDBHandler
1✔
396
        }
1✔
397

398
        return nil, tx.engine.multidbHandler.CreateUser(ctx, stmt.username, stmt.password, stmt.permission)
4✔
399
}
400

401
type AlterUserStmt struct {
402
        username   string
403
        password   string
404
        permission Permission
405
}
406

407
func (stmt *AlterUserStmt) readOnly() bool {
4✔
408
        return false
4✔
409
}
4✔
410

411
func (stmt *AlterUserStmt) requiredPrivileges() []SQLPrivilege {
4✔
412
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
413
}
4✔
414

415
func (stmt *AlterUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
416
        return nil
1✔
417
}
1✔
418

419
func (stmt *AlterUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
5✔
420
        if tx.IsExplicitCloseRequired() {
6✔
421
                return nil, fmt.Errorf("%w: user modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
422
        }
1✔
423

424
        if tx.engine.multidbHandler == nil {
5✔
425
                return nil, ErrUnspecifiedMultiDBHandler
1✔
426
        }
1✔
427

428
        return nil, tx.engine.multidbHandler.AlterUser(ctx, stmt.username, stmt.password, stmt.permission)
3✔
429
}
430

431
type DropUserStmt struct {
432
        username string
433
}
434

435
func (stmt *DropUserStmt) readOnly() bool {
3✔
436
        return false
3✔
437
}
3✔
438

439
func (stmt *DropUserStmt) requiredPrivileges() []SQLPrivilege {
3✔
440
        return []SQLPrivilege{SQLPrivilegeDrop}
3✔
441
}
3✔
442

443
func (stmt *DropUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
444
        return nil
1✔
445
}
1✔
446

447
func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
448
        if tx.IsExplicitCloseRequired() {
4✔
449
                return nil, fmt.Errorf("%w: user deletion can not be done within a transaction", ErrNonTransactionalStmt)
1✔
450
        }
1✔
451

452
        if tx.engine.multidbHandler == nil {
3✔
453
                return nil, ErrUnspecifiedMultiDBHandler
1✔
454
        }
1✔
455

456
        return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
1✔
457
}
458

459
type TableElem interface{}
460

461
type CreateTableStmt struct {
462
        table       string
463
        ifNotExists bool
464
        colsSpec    []*ColSpec
465
        checks      []CheckConstraint
466
        pkColNames  PrimaryKeyConstraint
467
}
468

469
func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
38✔
470
        return &CreateTableStmt{table: table, ifNotExists: ifNotExists, colsSpec: colsSpec, pkColNames: pkColNames}
38✔
471
}
38✔
472

473
func (stmt *CreateTableStmt) readOnly() bool {
66✔
474
        return false
66✔
475
}
66✔
476

477
func (stmt *CreateTableStmt) requiredPrivileges() []SQLPrivilege {
64✔
478
        return []SQLPrivilege{SQLPrivilegeCreate}
64✔
479
}
64✔
480

481
func (stmt *CreateTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
482
        return nil
4✔
483
}
4✔
484

485
func zeroRow(tableName string, cols []*ColSpec) *Row {
244✔
486
        r := Row{
244✔
487
                ValuesByPosition: make([]TypedValue, len(cols)),
244✔
488
                ValuesBySelector: make(map[string]TypedValue, len(cols)),
244✔
489
        }
244✔
490

244✔
491
        for i, col := range cols {
1,024✔
492
                v := zeroForType(col.colType)
780✔
493

780✔
494
                r.ValuesByPosition[i] = v
780✔
495
                r.ValuesBySelector[EncodeSelector("", tableName, col.colName)] = v
780✔
496
        }
780✔
497
        return &r
244✔
498
}
499

500
func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
235✔
501
        if err := stmt.validatePrimaryKey(); err != nil {
238✔
502
                return nil, err
3✔
503
        }
3✔
504

505
        if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
233✔
506
                return tx, nil
1✔
507
        }
1✔
508

509
        colSpecs := make(map[uint32]*ColSpec, len(stmt.colsSpec))
231✔
510
        for i, cs := range stmt.colsSpec {
951✔
511
                colSpecs[uint32(i)+1] = cs
720✔
512
        }
720✔
513

514
        row := zeroRow(stmt.table, stmt.colsSpec)
231✔
515
        for _, check := range stmt.checks {
240✔
516
                value, err := check.exp.reduce(tx, row, stmt.table)
9✔
517
                if err != nil {
11✔
518
                        return nil, err
2✔
519
                }
2✔
520

521
                if value.Type() != BooleanType {
7✔
522
                        return nil, ErrInvalidCheckConstraint
×
523
                }
×
524
        }
525

526
        nextUnnamedCheck := 0
229✔
527
        checks := make(map[string]CheckConstraint)
229✔
528
        for id, check := range stmt.checks {
236✔
529
                name := fmt.Sprintf("%s_check%d", stmt.table, nextUnnamedCheck+1)
7✔
530
                if check.name != "" {
9✔
531
                        name = check.name
2✔
532
                } else {
7✔
533
                        nextUnnamedCheck++
5✔
534
                }
5✔
535
                check.id = uint32(id)
7✔
536
                check.name = name
7✔
537
                checks[name] = check
7✔
538
        }
539

540
        table, err := tx.catalog.newTable(stmt.table, colSpecs, checks, uint32(len(colSpecs)))
229✔
541
        if err != nil {
235✔
542
                return nil, err
6✔
543
        }
6✔
544

545
        createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.primaryKeyCols()}
223✔
546
        _, err = createIndexStmt.execAt(ctx, tx, params)
223✔
547
        if err != nil {
228✔
548
                return nil, err
5✔
549
        }
5✔
550

551
        for _, col := range table.cols {
919✔
552
                if col.autoIncrement {
778✔
553
                        if len(table.primaryIndex.cols) > 1 || col.id != table.primaryIndex.cols[0].id {
78✔
554
                                return nil, ErrLimitedAutoIncrement
1✔
555
                        }
1✔
556
                }
557

558
                err := persistColumn(tx, col)
700✔
559
                if err != nil {
700✔
560
                        return nil, err
×
561
                }
×
562
        }
563

564
        for _, check := range checks {
224✔
565
                if err := persistCheck(tx, table, &check); err != nil {
7✔
566
                        return nil, err
×
567
                }
×
568
        }
569

570
        mappedKey := MapKey(tx.sqlPrefix(), catalogTablePrefix, EncodeID(DatabaseID), EncodeID(table.id))
217✔
571

217✔
572
        err = tx.set(mappedKey, nil, []byte(table.name))
217✔
573
        if err != nil {
217✔
574
                return nil, err
×
575
        }
×
576

577
        tx.mutatedCatalog = true
217✔
578

217✔
579
        return tx, nil
217✔
580
}
581

582
func (stmt *CreateTableStmt) validatePrimaryKey() error {
235✔
583
        n := 0
235✔
584
        for _, spec := range stmt.colsSpec {
962✔
585
                if spec.primaryKey {
732✔
586
                        n++
5✔
587
                }
5✔
588
        }
589

590
        if len(stmt.pkColNames) > 0 {
466✔
591
                n++
231✔
592
        }
231✔
593

594
        switch n {
235✔
595
        case 0:
1✔
596
                return ErrNoPrimaryKey
1✔
597
        case 1:
232✔
598
                return nil
232✔
599
        }
600
        return fmt.Errorf("\"%s\": %w", stmt.table, ErrMultiplePrimaryKeys)
2✔
601
}
602

603
func (stmt *CreateTableStmt) primaryKeyCols() []string {
223✔
604
        if len(stmt.pkColNames) > 0 {
444✔
605
                return stmt.pkColNames
221✔
606
        }
221✔
607

608
        for _, spec := range stmt.colsSpec {
4✔
609
                if spec.primaryKey {
4✔
610
                        return []string{spec.colName}
2✔
611
                }
2✔
612
        }
613
        return nil
×
614
}
615

616
func persistColumn(tx *SQLTx, col *Column) error {
720✔
617
        //{auto_incremental | nullable}{maxLen}{colNAME})
720✔
618
        v := make([]byte, 1+4+len(col.colName))
720✔
619

720✔
620
        if col.autoIncrement {
796✔
621
                v[0] = v[0] | autoIncrementFlag
76✔
622
        }
76✔
623

624
        if col.notNull {
767✔
625
                v[0] = v[0] | nullableFlag
47✔
626
        }
47✔
627

628
        binary.BigEndian.PutUint32(v[1:], uint32(col.MaxLen()))
720✔
629

720✔
630
        copy(v[5:], []byte(col.Name()))
720✔
631

720✔
632
        mappedKey := MapKey(
720✔
633
                tx.sqlPrefix(),
720✔
634
                catalogColumnPrefix,
720✔
635
                EncodeID(DatabaseID),
720✔
636
                EncodeID(col.table.id),
720✔
637
                EncodeID(col.id),
720✔
638
                []byte(col.colType),
720✔
639
        )
720✔
640

720✔
641
        return tx.set(mappedKey, nil, v)
720✔
642
}
643

644
func persistCheck(tx *SQLTx, table *Table, check *CheckConstraint) error {
7✔
645
        mappedKey := MapKey(
7✔
646
                tx.sqlPrefix(),
7✔
647
                catalogCheckPrefix,
7✔
648
                EncodeID(DatabaseID),
7✔
649
                EncodeID(table.id),
7✔
650
                EncodeID(check.id),
7✔
651
        )
7✔
652

7✔
653
        name := check.name
7✔
654
        expText := check.exp.String()
7✔
655

7✔
656
        val := make([]byte, 2+len(name)+len(expText))
7✔
657

7✔
658
        if len(name) > 256 {
7✔
659
                return fmt.Errorf("constraint name len: %w", ErrMaxLengthExceeded)
×
660
        }
×
661

662
        val[0] = byte(len(name)) - 1
7✔
663

7✔
664
        copy(val[1:], []byte(name))
7✔
665
        copy(val[1+len(name):], []byte(expText))
7✔
666

7✔
667
        return tx.set(mappedKey, nil, val)
7✔
668
}
669

670
type ColSpec struct {
671
        colName       string
672
        colType       SQLValueType
673
        maxLen        int
674
        autoIncrement bool
675
        notNull       bool
676
        primaryKey    bool
677
}
678

679
func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {
188✔
680
        return &ColSpec{
188✔
681
                colName:       name,
188✔
682
                colType:       colType,
188✔
683
                maxLen:        maxLen,
188✔
684
                autoIncrement: autoIncrement,
188✔
685
                notNull:       notNull,
188✔
686
        }
188✔
687
}
188✔
688

689
type CreateIndexStmt struct {
690
        unique      bool
691
        ifNotExists bool
692
        table       string
693
        cols        []string
694
}
695

696
func NewCreateIndexStmt(table string, cols []string, isUnique bool) *CreateIndexStmt {
72✔
697
        return &CreateIndexStmt{unique: isUnique, table: table, cols: cols}
72✔
698
}
72✔
699

700
func (stmt *CreateIndexStmt) readOnly() bool {
7✔
701
        return false
7✔
702
}
7✔
703

704
func (stmt *CreateIndexStmt) requiredPrivileges() []SQLPrivilege {
7✔
705
        return []SQLPrivilege{SQLPrivilegeCreate}
7✔
706
}
7✔
707

708
func (stmt *CreateIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
709
        return nil
1✔
710
}
1✔
711

712
func (stmt *CreateIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
374✔
713
        if len(stmt.cols) < 1 {
375✔
714
                return nil, ErrIllegalArguments
1✔
715
        }
1✔
716

717
        if len(stmt.cols) > MaxNumberOfColumnsInIndex {
374✔
718
                return nil, ErrMaxNumberOfColumnsInIndexExceeded
1✔
719
        }
1✔
720

721
        table, err := tx.catalog.GetTableByName(stmt.table)
372✔
722
        if err != nil {
374✔
723
                return nil, err
2✔
724
        }
2✔
725

726
        colIDs := make([]uint32, len(stmt.cols))
370✔
727

370✔
728
        indexKeyLen := 0
370✔
729

370✔
730
        for i, colName := range stmt.cols {
767✔
731
                col, err := table.GetColumnByName(colName)
397✔
732
                if err != nil {
402✔
733
                        return nil, err
5✔
734
                }
5✔
735

736
                if col.Type() == JSONType {
394✔
737
                        return nil, ErrCannotIndexJson
2✔
738
                }
2✔
739

740
                if variableSizedType(col.colType) && !tx.engine.lazyIndexConstraintValidation && (col.MaxLen() == 0 || col.MaxLen() > MaxKeyLen) {
392✔
741
                        return nil, fmt.Errorf("%w: can not create index using column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
2✔
742
                }
2✔
743

744
                indexKeyLen += col.MaxLen()
388✔
745

388✔
746
                colIDs[i] = col.id
388✔
747
        }
748

749
        if !tx.engine.lazyIndexConstraintValidation && indexKeyLen > MaxKeyLen {
361✔
750
                return nil, fmt.Errorf("%w: can not create index using columns '%v'. Max key length is %d", ErrLimitedKeyType, stmt.cols, MaxKeyLen)
×
751
        }
×
752

753
        if stmt.unique && table.primaryIndex != nil {
381✔
754
                // check table is empty
20✔
755
                pkPrefix := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id))
20✔
756
                _, _, err := tx.getWithPrefix(ctx, pkPrefix, nil)
20✔
757
                if errors.Is(err, store.ErrIndexNotFound) {
20✔
758
                        return nil, ErrTableDoesNotExist
×
759
                }
×
760
                if err == nil {
21✔
761
                        return nil, ErrLimitedIndexCreation
1✔
762
                } else if !errors.Is(err, store.ErrKeyNotFound) {
20✔
763
                        return nil, err
×
764
                }
×
765
        }
766

767
        index, err := table.newIndex(stmt.unique, colIDs)
360✔
768
        if errors.Is(err, ErrIndexAlreadyExists) && stmt.ifNotExists {
362✔
769
                return tx, nil
2✔
770
        }
2✔
771
        if err != nil {
362✔
772
                return nil, err
4✔
773
        }
4✔
774

775
        // v={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)}
776
        // TODO: currently only ASC order is supported
777
        colSpecLen := EncIDLen + 1
354✔
778

354✔
779
        encodedValues := make([]byte, 1+len(index.cols)*colSpecLen)
354✔
780

354✔
781
        if index.IsUnique() {
590✔
782
                encodedValues[0] = 1
236✔
783
        }
236✔
784

785
        for i, col := range index.cols {
735✔
786
                copy(encodedValues[1+i*colSpecLen:], EncodeID(col.id))
381✔
787
        }
381✔
788

789
        mappedKey := MapKey(tx.sqlPrefix(), catalogIndexPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(index.id))
354✔
790

354✔
791
        err = tx.set(mappedKey, nil, encodedValues)
354✔
792
        if err != nil {
354✔
793
                return nil, err
×
794
        }
×
795

796
        tx.mutatedCatalog = true
354✔
797

354✔
798
        return tx, nil
354✔
799
}
800

801
type AddColumnStmt struct {
802
        table   string
803
        colSpec *ColSpec
804
}
805

806
func NewAddColumnStmt(table string, colSpec *ColSpec) *AddColumnStmt {
6✔
807
        return &AddColumnStmt{table: table, colSpec: colSpec}
6✔
808
}
6✔
809

810
func (stmt *AddColumnStmt) readOnly() bool {
4✔
811
        return false
4✔
812
}
4✔
813

814
func (stmt *AddColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
815
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
816
}
4✔
817

818
func (stmt *AddColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
819
        return nil
1✔
820
}
1✔
821

822
func (stmt *AddColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
823
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
824
        if err != nil {
20✔
825
                return nil, err
1✔
826
        }
1✔
827

828
        col, err := table.newColumn(stmt.colSpec)
18✔
829
        if err != nil {
24✔
830
                return nil, err
6✔
831
        }
6✔
832

833
        err = persistColumn(tx, col)
12✔
834
        if err != nil {
12✔
835
                return nil, err
×
836
        }
×
837

838
        tx.mutatedCatalog = true
12✔
839

12✔
840
        return tx, nil
12✔
841
}
842

843
type RenameTableStmt struct {
844
        oldName string
845
        newName string
846
}
847

848
func (stmt *RenameTableStmt) readOnly() bool {
1✔
849
        return false
1✔
850
}
1✔
851

852
func (stmt *RenameTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
853
        return []SQLPrivilege{SQLPrivilegeAlter}
1✔
854
}
1✔
855

856
func (stmt *RenameTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
857
        return nil
1✔
858
}
1✔
859

860
func (stmt *RenameTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
861
        table, err := tx.catalog.renameTable(stmt.oldName, stmt.newName)
6✔
862
        if err != nil {
10✔
863
                return nil, err
4✔
864
        }
4✔
865

866
        // update table name
867
        mappedKey := MapKey(
2✔
868
                tx.sqlPrefix(),
2✔
869
                catalogTablePrefix,
2✔
870
                EncodeID(DatabaseID),
2✔
871
                EncodeID(table.id),
2✔
872
        )
2✔
873
        err = tx.set(mappedKey, nil, []byte(stmt.newName))
2✔
874
        if err != nil {
2✔
875
                return nil, err
×
876
        }
×
877

878
        tx.mutatedCatalog = true
2✔
879

2✔
880
        return tx, nil
2✔
881
}
882

883
type RenameColumnStmt struct {
884
        table   string
885
        oldName string
886
        newName string
887
}
888

889
func NewRenameColumnStmt(table, oldName, newName string) *RenameColumnStmt {
3✔
890
        return &RenameColumnStmt{table: table, oldName: oldName, newName: newName}
3✔
891
}
3✔
892

893
func (stmt *RenameColumnStmt) readOnly() bool {
4✔
894
        return false
4✔
895
}
4✔
896

897
func (stmt *RenameColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
898
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
899
}
4✔
900

901
func (stmt *RenameColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
902
        return nil
1✔
903
}
1✔
904

905
func (stmt *RenameColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
10✔
906
        table, err := tx.catalog.GetTableByName(stmt.table)
10✔
907
        if err != nil {
11✔
908
                return nil, err
1✔
909
        }
1✔
910

911
        col, err := table.renameColumn(stmt.oldName, stmt.newName)
9✔
912
        if err != nil {
12✔
913
                return nil, err
3✔
914
        }
3✔
915

916
        err = persistColumn(tx, col)
6✔
917
        if err != nil {
6✔
918
                return nil, err
×
919
        }
×
920

921
        tx.mutatedCatalog = true
6✔
922

6✔
923
        return tx, nil
6✔
924
}
925

926
type DropColumnStmt struct {
927
        table   string
928
        colName string
929
}
930

931
func NewDropColumnStmt(table, colName string) *DropColumnStmt {
8✔
932
        return &DropColumnStmt{table: table, colName: colName}
8✔
933
}
8✔
934

935
func (stmt *DropColumnStmt) readOnly() bool {
2✔
936
        return false
2✔
937
}
2✔
938

939
func (stmt *DropColumnStmt) requiredPrivileges() []SQLPrivilege {
2✔
940
        return []SQLPrivilege{SQLPrivilegeDrop}
2✔
941
}
2✔
942

943
func (stmt *DropColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
944
        return nil
1✔
945
}
1✔
946

947
func (stmt *DropColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
948
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
949
        if err != nil {
21✔
950
                return nil, err
2✔
951
        }
2✔
952

953
        col, err := table.GetColumnByName(stmt.colName)
17✔
954
        if err != nil {
21✔
955
                return nil, err
4✔
956
        }
4✔
957

958
        err = canDropColumn(tx, table, col)
13✔
959
        if err != nil {
14✔
960
                return nil, err
1✔
961
        }
1✔
962

963
        err = table.deleteColumn(col)
12✔
964
        if err != nil {
16✔
965
                return nil, err
4✔
966
        }
4✔
967

968
        err = persistColumnDeletion(ctx, tx, col)
8✔
969
        if err != nil {
8✔
970
                return nil, err
×
971
        }
×
972

973
        tx.mutatedCatalog = true
8✔
974

8✔
975
        return tx, nil
8✔
976
}
977

978
func canDropColumn(tx *SQLTx, table *Table, col *Column) error {
13✔
979
        colSpecs := make([]*ColSpec, 0, len(table.Cols())-1)
13✔
980
        for _, c := range table.cols {
86✔
981
                if c.id != col.id {
133✔
982
                        colSpecs = append(colSpecs, &ColSpec{colName: c.Name(), colType: c.Type()})
60✔
983
                }
60✔
984
        }
985

986
        row := zeroRow(table.Name(), colSpecs)
13✔
987
        for name, check := range table.checkConstraints {
17✔
988
                _, err := check.exp.reduce(tx, row, table.name)
4✔
989
                if errors.Is(err, ErrColumnDoesNotExist) {
5✔
990
                        return fmt.Errorf("%w %s because %s constraint requires it", ErrCannotDropColumn, col.Name(), name)
1✔
991
                }
1✔
992

993
                if err != nil {
3✔
994
                        return err
×
995
                }
×
996
        }
997
        return nil
12✔
998
}
999

1000
func persistColumnDeletion(ctx context.Context, tx *SQLTx, col *Column) error {
9✔
1001
        mappedKey := MapKey(
9✔
1002
                tx.sqlPrefix(),
9✔
1003
                catalogColumnPrefix,
9✔
1004
                EncodeID(DatabaseID),
9✔
1005
                EncodeID(col.table.id),
9✔
1006
                EncodeID(col.id),
9✔
1007
                []byte(col.colType),
9✔
1008
        )
9✔
1009

9✔
1010
        return tx.delete(ctx, mappedKey)
9✔
1011
}
9✔
1012

1013
type DropConstraintStmt struct {
1014
        table          string
1015
        constraintName string
1016
}
1017

1018
func (stmt *DropConstraintStmt) readOnly() bool {
×
1019
        return false
×
1020
}
×
1021

1022
func (stmt *DropConstraintStmt) requiredPrivileges() []SQLPrivilege {
×
1023
        return []SQLPrivilege{SQLPrivilegeDrop}
×
1024
}
×
1025

1026
func (stmt *DropConstraintStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
1027
        table, err := tx.catalog.GetTableByName(stmt.table)
4✔
1028
        if err != nil {
4✔
1029
                return nil, err
×
1030
        }
×
1031

1032
        id, err := table.deleteCheck(stmt.constraintName)
4✔
1033
        if err != nil {
5✔
1034
                return nil, err
1✔
1035
        }
1✔
1036

1037
        err = persistCheckDeletion(ctx, tx, table.id, id)
3✔
1038

3✔
1039
        tx.mutatedCatalog = true
3✔
1040

3✔
1041
        return tx, err
3✔
1042
}
1043

1044
func persistCheckDeletion(ctx context.Context, tx *SQLTx, tableID uint32, checkId uint32) error {
3✔
1045
        mappedKey := MapKey(
3✔
1046
                tx.sqlPrefix(),
3✔
1047
                catalogCheckPrefix,
3✔
1048
                EncodeID(DatabaseID),
3✔
1049
                EncodeID(tableID),
3✔
1050
                EncodeID(checkId),
3✔
1051
        )
3✔
1052
        return tx.delete(ctx, mappedKey)
3✔
1053
}
3✔
1054

1055
func (stmt *DropConstraintStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
1056
        return nil
×
1057
}
×
1058

1059
type UpsertIntoStmt struct {
1060
        isInsert   bool
1061
        tableRef   *tableRef
1062
        cols       []string
1063
        ds         DataSource
1064
        onConflict *OnConflictDo
1065
}
1066

1067
func (stmt *UpsertIntoStmt) readOnly() bool {
101✔
1068
        return false
101✔
1069
}
101✔
1070

1071
func (stmt *UpsertIntoStmt) requiredPrivileges() []SQLPrivilege {
101✔
1072
        privileges := stmt.privileges()
101✔
1073
        if stmt.ds != nil {
200✔
1074
                privileges = append(privileges, stmt.ds.requiredPrivileges()...)
99✔
1075
        }
99✔
1076
        return privileges
101✔
1077
}
1078

1079
func (stmt *UpsertIntoStmt) privileges() []SQLPrivilege {
101✔
1080
        if stmt.isInsert {
190✔
1081
                return []SQLPrivilege{SQLPrivilegeInsert}
89✔
1082
        }
89✔
1083
        return []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate}
12✔
1084
}
1085

1086
func NewUpsertIntoStmt(table string, cols []string, ds DataSource, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt {
120✔
1087
        return &UpsertIntoStmt{
120✔
1088
                isInsert:   isInsert,
120✔
1089
                tableRef:   NewTableRef(table, ""),
120✔
1090
                cols:       cols,
120✔
1091
                ds:         ds,
120✔
1092
                onConflict: onConflict,
120✔
1093
        }
120✔
1094
}
120✔
1095

1096
type RowSpec struct {
1097
        Values []ValueExp
1098
}
1099

1100
func NewRowSpec(values []ValueExp) *RowSpec {
129✔
1101
        return &RowSpec{
129✔
1102
                Values: values,
129✔
1103
        }
129✔
1104
}
129✔
1105

1106
type OnConflictDo struct{}
1107

1108
func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
11✔
1109
        ds, ok := stmt.ds.(*valuesDataSource)
11✔
1110
        if !ok {
11✔
1111
                return stmt.ds.inferParameters(ctx, tx, params)
×
1112
        }
×
1113

1114
        emptyDescriptors := make(map[string]ColDescriptor)
11✔
1115
        for _, row := range ds.rows {
23✔
1116
                if len(stmt.cols) != len(row.Values) {
13✔
1117
                        return ErrInvalidNumberOfValues
1✔
1118
                }
1✔
1119

1120
                for i, val := range row.Values {
36✔
1121
                        table, err := stmt.tableRef.referencedTable(tx)
25✔
1122
                        if err != nil {
26✔
1123
                                return err
1✔
1124
                        }
1✔
1125

1126
                        col, err := table.GetColumnByName(stmt.cols[i])
24✔
1127
                        if err != nil {
25✔
1128
                                return err
1✔
1129
                        }
1✔
1130

1131
                        err = val.requiresType(col.colType, emptyDescriptors, params, table.name)
23✔
1132
                        if err != nil {
25✔
1133
                                return err
2✔
1134
                        }
2✔
1135
                }
1136
        }
1137
        return nil
6✔
1138
}
1139

1140
func (stmt *UpsertIntoStmt) validate(table *Table) (map[uint32]int, error) {
2,115✔
1141
        selPosByColID := make(map[uint32]int, len(stmt.cols))
2,115✔
1142

2,115✔
1143
        for i, c := range stmt.cols {
9,736✔
1144
                col, err := table.GetColumnByName(c)
7,621✔
1145
                if err != nil {
7,623✔
1146
                        return nil, err
2✔
1147
                }
2✔
1148

1149
                _, duplicated := selPosByColID[col.id]
7,619✔
1150
                if duplicated {
7,620✔
1151
                        return nil, fmt.Errorf("%w (%s)", ErrDuplicatedColumn, col.colName)
1✔
1152
                }
1✔
1153

1154
                selPosByColID[col.id] = i
7,618✔
1155
        }
1156

1157
        return selPosByColID, nil
2,112✔
1158
}
1159

1160
// execAt resolves the statement's data source and writes every row it
// produces into the referenced table, returning tx on success.
//
// For each input row it:
//   - maps the row's positional values onto table columns, using the
//     positions computed by validate;
//   - rejects an unspecified column that is NOT NULL and not
//     auto-incremental (ErrNotNullableColumnCannotBeNull);
//   - on INSERT, assigns table.maxPK+1 to an unspecified auto-incremental
//     column and records it in tx.firstInsertedPKs / tx.lastInsertedPKs;
//   - substitutes params into, and reduces, each specified value;
//   - evaluates the table's check constraints on the resulting row;
//   - encodes the primary key and probes for an existing entry: with INSERT
//     semantics an existing key is store.ErrKeyAlreadyExists unless
//     onConflict is set, in which case the row is skipped;
//   - stores the row through tx.doUpsert.
func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2,118✔
1161
        table, err := stmt.tableRef.referencedTable(tx)
2,118✔
1162
        if err != nil {
2,121✔
1163
                return nil, err
3✔
1164
        }
3✔
1165

1166
        selPosByColID, err := stmt.validate(table)
2,115✔
1167
        if err != nil {
2,118✔
1168
                return nil, err
3✔
1169
        }
3✔
1170

1171
        // r is reused for every input row: the loop over table.cols below
        // overwrites all of its positions and selectors before check
        // constraints are evaluated.
        r := &Row{
2,112✔
1172
                ValuesByPosition: make([]TypedValue, len(table.cols)),
2,112✔
1173
                ValuesBySelector: make(map[string]TypedValue),
2,112✔
1174
        }
2,112✔
1175

2,112✔
1176
        reader, err := stmt.ds.Resolve(ctx, tx, params, nil)
2,112✔
1177
        if err != nil {
2,112✔
1178
                return nil, err
×
1179
        }
×
1180
        defer reader.Close()
2,112✔
1181

2,112✔
1182
        for {
6,490✔
1183
                row, err := reader.Read(ctx)
4,378✔
1184
                if errors.Is(err, ErrNoMoreRows) {
6,447✔
1185
                        break
2,069✔
1186
                }
1187
                if err != nil {
2,319✔
1188
                        return nil, err
10✔
1189
                }
10✔
1190

1191
                if len(row.ValuesByPosition) != len(stmt.cols) {
2,301✔
1192
                        return nil, ErrInvalidNumberOfValues
2✔
1193
                }
2✔
1194

1195
                valuesByColID := make(map[uint32]TypedValue)
2,297✔
1196

2,297✔
1197
                // pkMustExist is set when an explicit auto-incremental value is
                // not greater than table.maxPK: such a value is only accepted
                // if the key already exists (checked after tx.get below).
                var pkMustExist bool
2,297✔
1198

2,297✔
1199
                // NOTE: map iteration order is unspecified, so when several
                // columns are invalid the reported error may differ between runs.
                for colID, col := range table.colsByID {
11,744✔
1200
                        colPos, specified := selPosByColID[colID]
9,447✔
1201
                        if !specified {
10,693✔
1202
                                // TODO: Default values
1,246✔
1203
                                if col.notNull && !col.autoIncrement {
1,247✔
1204
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
1✔
1205
                                }
1✔
1206

1207
                                // inject auto-incremental pk value
1208
                                if stmt.isInsert && col.autoIncrement {
2,303✔
1209
                                        // current implementation assumes only PK can be set as autoincremental
1,058✔
1210
                                        table.maxPK++
1,058✔
1211

1,058✔
1212
                                        pkCol := table.primaryIndex.cols[0]
1,058✔
1213
                                        valuesByColID[pkCol.id] = &Integer{val: table.maxPK}
1,058✔
1214

1,058✔
1215
                                        if _, ok := tx.firstInsertedPKs[table.name]; !ok {
1,924✔
1216
                                                tx.firstInsertedPKs[table.name] = table.maxPK
866✔
1217
                                        }
866✔
1218
                                        tx.lastInsertedPKs[table.name] = table.maxPK
1,058✔
1219
                                }
1220

1221
                                continue
1,245✔
1222
                        }
1223

1224
                        // value was specified
1225
                        cVal := row.ValuesByPosition[colPos]
8,201✔
1226

8,201✔
1227
                        val, err := cVal.substitute(params)
8,201✔
1228
                        if err != nil {
8,201✔
1229
                                return nil, err
×
1230
                        }
×
1231

1232
                        rval, err := val.reduce(tx, nil, table.name)
8,201✔
1233
                        if err != nil {
8,201✔
1234
                                return nil, err
×
1235
                        }
×
1236

1237
                        if rval.IsNull() {
8,299✔
1238
                                if col.notNull || col.autoIncrement {
98✔
1239
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
×
1240
                                }
×
1241

1242
                                continue
98✔
1243
                        }
1244

1245
                        if col.autoIncrement {
8,122✔
1246
                                // validate specified value
19✔
1247
                                nl, isNumber := rval.RawValue().(int64)
19✔
1248
                                if !isNumber {
19✔
1249
                                        return nil, fmt.Errorf("%w (expecting numeric value)", ErrInvalidValue)
×
1250
                                }
×
1251

1252
                                // NOTE(review): table.maxPK is not advanced here when nl exceeds
                                // it; presumably it is refreshed elsewhere — confirm.
                                pkMustExist = nl <= table.maxPK
19✔
1253

19✔
1254
                                if _, ok := tx.firstInsertedPKs[table.name]; !ok {
38✔
1255
                                        tx.firstInsertedPKs[table.name] = nl
19✔
1256
                                }
19✔
1257
                                tx.lastInsertedPKs[table.name] = nl
19✔
1258
                        }
1259

1260
                        valuesByColID[colID] = rval
8,103✔
1261
                }
1262

1263
                for i, col := range table.cols {
11,739✔
1264
                        v := valuesByColID[col.id]
9,443✔
1265

9,443✔
1266
                        if v == nil {
9,727✔
1267
                                v = NewNull(AnyType)
284✔
1268
                        } else if len(table.checkConstraints) > 0 && col.Type() == JSONType {
9,448✔
1269
                                // Only the copy placed in r (used by the check constraints) is
                                // parsed into a JSON value; valuesByColID keeps v unchanged.
                                // A non-string raw value yields s == "" here.
                                s, _ := v.RawValue().(string)
5✔
1270
                                jsonVal, err := NewJsonFromString(s)
5✔
1271
                                if err != nil {
5✔
1272
                                        return nil, err
×
1273
                                }
×
1274
                                v = jsonVal
5✔
1275
                        }
1276

1277
                        r.ValuesByPosition[i] = v
9,443✔
1278
                        r.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
9,443✔
1279
                }
1280

1281
                if err := checkConstraints(tx, table.checkConstraints, r, table.name); err != nil {
2,302✔
1282
                        return nil, err
6✔
1283
                }
6✔
1284

1285
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
2,290✔
1286
                if err != nil {
2,295✔
1287
                        return nil, err
5✔
1288
                }
5✔
1289

1290
                // pk entry
1291
                mappedPKey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
2,285✔
1292
                if len(mappedPKey) > MaxKeyLen {
2,285✔
1293
                        return nil, ErrMaxKeyLengthExceeded
×
1294
                }
×
1295

1296
                // probe for an existing row: err == nil means the key exists
                _, err = tx.get(ctx, mappedPKey)
2,285✔
1297
                if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
2,285✔
1298
                        return nil, err
×
1299
                }
×
1300

1301
                if errors.Is(err, store.ErrKeyNotFound) && pkMustExist {
2,287✔
1302
                        return nil, fmt.Errorf("%w: specified value must be greater than current one", ErrInvalidValue)
2✔
1303
                }
2✔
1304

1305
                if stmt.isInsert {
4,385✔
1306
                        if err == nil && stmt.onConflict == nil {
2,106✔
1307
                                return nil, store.ErrKeyAlreadyExists
4✔
1308
                        }
4✔
1309

1310
                        if err == nil && stmt.onConflict != nil {
2,101✔
1311
                                // TODO: conflict resolution may be extended. Currently only supports "ON CONFLICT DO NOTHING"
3✔
1312
                                continue
3✔
1313
                        }
1314
                }
1315

1316
                // with UPSERT semantics doUpsert also deprecates the secondary
                // index entries of the previous version of the row
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, !stmt.isInsert)
2,276✔
1317
                if err != nil {
2,289✔
1318
                        return nil, err
13✔
1319
                }
13✔
1320
        }
1321
        return tx, nil
2,069✔
1322
}
1323

1324
func checkConstraints(tx *SQLTx, checks map[string]CheckConstraint, row *Row, table string) error {
2,329✔
1325
        for _, check := range checks {
2,377✔
1326
                val, err := check.exp.reduce(tx, row, table)
48✔
1327
                if err != nil {
49✔
1328
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, err)
1✔
1329
                }
1✔
1330

1331
                if val.Type() != BooleanType {
47✔
1332
                        return ErrInvalidCheckConstraint
×
1333
                }
×
1334

1335
                if !val.RawValue().(bool) {
54✔
1336
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, check.exp.String())
7✔
1337
                }
7✔
1338
        }
1339
        return nil
2,321✔
1340
}
1341

1342
func (tx *SQLTx) encodeRowValue(valuesByColID map[uint32]TypedValue, table *Table) ([]byte, error) {
2,464✔
1343
        valbuf := bytes.Buffer{}
2,464✔
1344

2,464✔
1345
        // null values are not serialized
2,464✔
1346
        encodedVals := 0
2,464✔
1347
        for _, v := range valuesByColID {
12,042✔
1348
                if !v.IsNull() {
19,138✔
1349
                        encodedVals++
9,560✔
1350
                }
9,560✔
1351
        }
1352

1353
        b := make([]byte, EncLenLen)
2,464✔
1354
        binary.BigEndian.PutUint32(b, uint32(encodedVals))
2,464✔
1355

2,464✔
1356
        _, err := valbuf.Write(b)
2,464✔
1357
        if err != nil {
2,464✔
1358
                return nil, err
×
1359
        }
×
1360

1361
        for _, col := range table.cols {
12,302✔
1362
                rval, specified := valuesByColID[col.id]
9,838✔
1363
                if !specified || rval.IsNull() {
10,122✔
1364
                        continue
284✔
1365
                }
1366

1367
                b := make([]byte, EncIDLen)
9,554✔
1368
                binary.BigEndian.PutUint32(b, uint32(col.id))
9,554✔
1369

9,554✔
1370
                _, err = valbuf.Write(b)
9,554✔
1371
                if err != nil {
9,554✔
1372
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1373
                }
×
1374

1375
                encVal, err := EncodeValue(rval, col.colType, col.MaxLen())
9,554✔
1376
                if err != nil {
9,562✔
1377
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
8✔
1378
                }
8✔
1379

1380
                _, err = valbuf.Write(encVal)
9,546✔
1381
                if err != nil {
9,546✔
1382
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1383
                }
×
1384
        }
1385

1386
        return valbuf.Bytes(), nil
2,456✔
1387
}
1388

1389
func (tx *SQLTx) doUpsert(ctx context.Context, pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table, reuseIndex bool) error {
2,307✔
1390
        var reusableIndexEntries map[uint32]struct{}
2,307✔
1391

2,307✔
1392
        if reuseIndex && len(table.indexes) > 1 {
2,364✔
1393
                currPKRow, err := tx.fetchPKRow(ctx, table, valuesByColID)
57✔
1394
                if err == nil {
93✔
1395
                        currValuesByColID := make(map[uint32]TypedValue, len(currPKRow.ValuesBySelector))
36✔
1396

36✔
1397
                        for _, col := range table.cols {
161✔
1398
                                encSel := EncodeSelector("", table.name, col.colName)
125✔
1399
                                currValuesByColID[col.id] = currPKRow.ValuesBySelector[encSel]
125✔
1400
                        }
125✔
1401

1402
                        reusableIndexEntries, err = tx.deprecateIndexEntries(pkEncVals, currValuesByColID, valuesByColID, table)
36✔
1403
                        if err != nil {
36✔
1404
                                return err
×
1405
                        }
×
1406
                } else if !errors.Is(err, ErrNoMoreRows) {
21✔
1407
                        return err
×
1408
                }
×
1409
        }
1410

1411
        rowKey := MapKey(tx.sqlPrefix(), RowPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(PKIndexID), pkEncVals)
2,307✔
1412

2,307✔
1413
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
2,307✔
1414
        if err != nil {
2,315✔
1415
                return err
8✔
1416
        }
8✔
1417

1418
        err = tx.set(rowKey, nil, encodedRowValue)
2,299✔
1419
        if err != nil {
2,299✔
1420
                return err
×
1421
        }
×
1422

1423
        // create in-memory and validate entries for secondary indexes
1424
        for _, index := range table.indexes {
5,493✔
1425
                if index.IsPrimary() {
5,493✔
1426
                        continue
2,299✔
1427
                }
1428

1429
                if reusableIndexEntries != nil {
972✔
1430
                        _, reusable := reusableIndexEntries[index.id]
77✔
1431
                        if reusable {
127✔
1432
                                continue
50✔
1433
                        }
1434
                }
1435

1436
                encodedValues := make([][]byte, 2+len(index.cols))
845✔
1437
                encodedValues[0] = EncodeID(table.id)
845✔
1438
                encodedValues[1] = EncodeID(index.id)
845✔
1439

845✔
1440
                indexKeyLen := 0
845✔
1441

845✔
1442
                for i, col := range index.cols {
1,765✔
1443
                        rval, specified := valuesByColID[col.id]
920✔
1444
                        if !specified {
993✔
1445
                                rval = &NullValue{t: col.colType}
73✔
1446
                        }
73✔
1447

1448
                        encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
920✔
1449
                        if err != nil {
920✔
1450
                                return fmt.Errorf("%w: index on '%s' and column '%s'", err, index.Name(), col.colName)
×
1451
                        }
×
1452

1453
                        if n > MaxKeyLen {
920✔
1454
                                return fmt.Errorf("%w: can not index entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1455
                        }
×
1456

1457
                        indexKeyLen += n
920✔
1458

920✔
1459
                        encodedValues[i+2] = encVal
920✔
1460
                }
1461

1462
                if indexKeyLen > MaxKeyLen {
845✔
1463
                        return fmt.Errorf("%w: can not index entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1464
                }
×
1465

1466
                smkey := MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...)
845✔
1467

845✔
1468
                // no other equivalent entry should be already indexed
845✔
1469
                if index.IsUnique() {
922✔
1470
                        _, valRef, err := tx.getWithPrefix(ctx, smkey, nil)
77✔
1471
                        if err == nil && (valRef.KVMetadata() == nil || !valRef.KVMetadata().Deleted()) {
82✔
1472
                                return store.ErrKeyAlreadyExists
5✔
1473
                        } else if !errors.Is(err, store.ErrKeyNotFound) {
77✔
1474
                                return err
×
1475
                        }
×
1476
                }
1477

1478
                err = tx.setTransient(smkey, nil, encodedRowValue) // only-indexable
840✔
1479
                if err != nil {
840✔
1480
                        return err
×
1481
                }
×
1482
        }
1483

1484
        tx.updatedRows++
2,294✔
1485

2,294✔
1486
        return nil
2,294✔
1487
}
1488

1489
func encodedKey(index *Index, valuesByColID map[uint32]TypedValue) ([]byte, error) {
13,977✔
1490
        valbuf := bytes.Buffer{}
13,977✔
1491

13,977✔
1492
        indexKeyLen := 0
13,977✔
1493

13,977✔
1494
        for _, col := range index.cols {
27,966✔
1495
                rval, specified := valuesByColID[col.id]
13,989✔
1496
                if !specified || rval.IsNull() {
13,992✔
1497
                        return nil, ErrPKCanNotBeNull
3✔
1498
                }
3✔
1499

1500
                encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
13,986✔
1501
                if err != nil {
13,988✔
1502
                        return nil, fmt.Errorf("%w: index of table '%s' and column '%s'", err, index.table.name, col.colName)
2✔
1503
                }
2✔
1504

1505
                if n > MaxKeyLen {
13,984✔
1506
                        return nil, fmt.Errorf("%w: invalid key entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1507
                }
×
1508

1509
                indexKeyLen += n
13,984✔
1510

13,984✔
1511
                _, err = valbuf.Write(encVal)
13,984✔
1512
                if err != nil {
13,984✔
1513
                        return nil, err
×
1514
                }
×
1515
        }
1516

1517
        if indexKeyLen > MaxKeyLen {
13,972✔
1518
                return nil, fmt.Errorf("%w: invalid key entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1519
        }
×
1520

1521
        return valbuf.Bytes(), nil
13,972✔
1522
}
1523

1524
func (tx *SQLTx) fetchPKRow(ctx context.Context, table *Table, valuesByColID map[uint32]TypedValue) (*Row, error) {
57✔
1525
        pkRanges := make(map[uint32]*typedValueRange, len(table.primaryIndex.cols))
57✔
1526

57✔
1527
        for _, pkCol := range table.primaryIndex.cols {
114✔
1528
                pkVal := valuesByColID[pkCol.id]
57✔
1529

57✔
1530
                pkRanges[pkCol.id] = &typedValueRange{
57✔
1531
                        lRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1532
                        hRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1533
                }
57✔
1534
        }
57✔
1535

1536
        scanSpecs := &ScanSpecs{
57✔
1537
                Index:         table.primaryIndex,
57✔
1538
                rangesByColID: pkRanges,
57✔
1539
        }
57✔
1540

57✔
1541
        r, err := newRawRowReader(tx, nil, table, period{}, table.name, scanSpecs)
57✔
1542
        if err != nil {
57✔
1543
                return nil, err
×
1544
        }
×
1545

1546
        defer func() {
114✔
1547
                r.Close()
57✔
1548
        }()
57✔
1549

1550
        return r.Read(ctx)
57✔
1551
}
1552

1553
// deprecateIndexEntries marks the secondary index entries of the current
// version of a row as deleted when the row's new values change them.
//
// For every secondary index of table it compares, column by column, the
// current (currValuesByColID) and new (newValuesByColID) values; missing
// values are treated as typed NULLs.
//   - If all values are equal, the index id is added to the returned
//     reusableIndexEntries set and nothing is written.
//   - Otherwise an entry is set under the key built from encodedValues. It
//     carries KV metadata flagged AsDeleted and the encoding of the current
//     row.
//
// NOTE(review): encodedValues has 2+len(index.cols)+1 slots: table id, index
// id, then pkEncVals in the last slot. Column encodings, however, are stored
// at i+3, which leaves slot 2 nil and overwrites pkEncVals with the last
// column. Assuming MapKey simply concatenates its parts, the key is therefore
// MappedPrefix|tableID|indexID|cols, without the primary key. That is the
// same layout doUpsert uses for secondary index keys. Confirm whether this is
// intended or whether i+2 (keeping pkEncVals) was meant.
1554
func (tx *SQLTx) deprecateIndexEntries(
1555
        pkEncVals []byte,
1556
        currValuesByColID, newValuesByColID map[uint32]TypedValue,
1557
        table *Table) (reusableIndexEntries map[uint32]struct{}, err error) {
36✔
1558

36✔
1559
        encodedRowValue, err := tx.encodeRowValue(currValuesByColID, table)
36✔
1560
        if err != nil {
36✔
1561
                return nil, err
×
1562
        }
×
1563

1564
        reusableIndexEntries = make(map[uint32]struct{})
36✔
1565

36✔
1566
        for _, index := range table.indexes {
149✔
1567
                if index.IsPrimary() {
149✔
1568
                        continue
36✔
1569
                }
1570

1571
                encodedValues := make([][]byte, 2+len(index.cols)+1)
77✔
1572
                encodedValues[0] = EncodeID(table.id)
77✔
1573
                encodedValues[1] = EncodeID(index.id)
77✔
1574
                encodedValues[len(encodedValues)-1] = pkEncVals
77✔
1575

77✔
1576
                // the current index entry is deleted only if it differs from the new one
77✔
1577
                sameIndexKey := true
77✔
1578

77✔
1579
                for i, col := range index.cols {
159✔
1580
                        currVal, specified := currValuesByColID[col.id]
82✔
1581
                        if !specified {
82✔
1582
                                currVal = &NullValue{t: col.colType}
×
1583
                        }
×
1584

1585
                        newVal, specified := newValuesByColID[col.id]
82✔
1586
                        if !specified {
86✔
1587
                                newVal = &NullValue{t: col.colType}
4✔
1588
                        }
4✔
1589

1590
                        r, err := currVal.Compare(newVal)
82✔
1591
                        if err != nil {
82✔
1592
                                return nil, err
×
1593
                        }
×
1594

1595
                        sameIndexKey = sameIndexKey && r == 0
82✔
1596

82✔
1597
                        // NOTE(review): the encoding error and the key length are discarded here.
                        encVal, _, _ := EncodeValueAsKey(currVal, col.colType, col.MaxLen())
82✔
1598

82✔
1599
                        encodedValues[i+3] = encVal
82✔
1600
                }
1601

1602
                // unchanged key: keep the entry; otherwise mark the current entry as deleted
1603
                if sameIndexKey {
127✔
1604
                        reusableIndexEntries[index.id] = struct{}{}
50✔
1605
                } else {
77✔
1606
                        md := store.NewKVMetadata()
27✔
1607

27✔
1608
                        md.AsDeleted(true)
27✔
1609

27✔
1610
                        err = tx.set(MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...), md, encodedRowValue)
27✔
1611
                        if err != nil {
27✔
1612
                                return nil, err
×
1613
                        }
×
1614
                }
1615
        }
1616

1617
        return reusableIndexEntries, nil
36✔
1618
}
1619

1620
type UpdateStmt struct {
1621
        tableRef *tableRef
1622
        where    ValueExp
1623
        updates  []*colUpdate
1624
        indexOn  []string
1625
        limit    ValueExp
1626
        offset   ValueExp
1627
}
1628

1629
type colUpdate struct {
1630
        col string
1631
        op  CmpOperator
1632
        val ValueExp
1633
}
1634

1635
func (stmt *UpdateStmt) readOnly() bool {
4✔
1636
        return false
4✔
1637
}
4✔
1638

1639
func (stmt *UpdateStmt) requiredPrivileges() []SQLPrivilege {
4✔
1640
        return []SQLPrivilege{SQLPrivilegeUpdate}
4✔
1641
}
4✔
1642

1643
func (stmt *UpdateStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1644
        selectStmt := &SelectStmt{
1✔
1645
                ds:    stmt.tableRef,
1✔
1646
                where: stmt.where,
1✔
1647
        }
1✔
1648

1✔
1649
        err := selectStmt.inferParameters(ctx, tx, params)
1✔
1650
        if err != nil {
1✔
1651
                return err
×
1652
        }
×
1653

1654
        table, err := stmt.tableRef.referencedTable(tx)
1✔
1655
        if err != nil {
1✔
1656
                return err
×
1657
        }
×
1658

1659
        for _, update := range stmt.updates {
2✔
1660
                col, err := table.GetColumnByName(update.col)
1✔
1661
                if err != nil {
1✔
1662
                        return err
×
1663
                }
×
1664

1665
                err = update.val.requiresType(col.colType, make(map[string]ColDescriptor), params, table.name)
1✔
1666
                if err != nil {
1✔
1667
                        return err
×
1668
                }
×
1669
        }
1670

1671
        return nil
1✔
1672
}
1673

1674
func (stmt *UpdateStmt) validate(table *Table) error {
21✔
1675
        colIDs := make(map[uint32]struct{}, len(stmt.updates))
21✔
1676

21✔
1677
        for _, update := range stmt.updates {
44✔
1678
                if update.op != EQ {
23✔
1679
                        return ErrIllegalArguments
×
1680
                }
×
1681

1682
                col, err := table.GetColumnByName(update.col)
23✔
1683
                if err != nil {
24✔
1684
                        return err
1✔
1685
                }
1✔
1686

1687
                if table.PrimaryIndex().IncludesCol(col.id) {
22✔
1688
                        return ErrPKCanNotBeUpdated
×
1689
                }
×
1690

1691
                _, duplicated := colIDs[col.id]
22✔
1692
                if duplicated {
22✔
1693
                        return ErrDuplicatedColumn
×
1694
                }
×
1695

1696
                colIDs[col.id] = struct{}{}
22✔
1697
        }
1698

1699
        return nil
20✔
1700
}
1701

1702
func (stmt *UpdateStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
22✔
1703
        selectStmt := &SelectStmt{
22✔
1704
                ds:      stmt.tableRef,
22✔
1705
                where:   stmt.where,
22✔
1706
                indexOn: stmt.indexOn,
22✔
1707
                limit:   stmt.limit,
22✔
1708
                offset:  stmt.offset,
22✔
1709
        }
22✔
1710

22✔
1711
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
22✔
1712
        if err != nil {
23✔
1713
                return nil, err
1✔
1714
        }
1✔
1715
        defer rowReader.Close()
21✔
1716

21✔
1717
        table := rowReader.ScanSpecs().Index.table
21✔
1718

21✔
1719
        err = stmt.validate(table)
21✔
1720
        if err != nil {
22✔
1721
                return nil, err
1✔
1722
        }
1✔
1723

1724
        cols, err := rowReader.colsBySelector(ctx)
20✔
1725
        if err != nil {
20✔
1726
                return nil, err
×
1727
        }
×
1728

1729
        for {
71✔
1730
                row, err := rowReader.Read(ctx)
51✔
1731
                if errors.Is(err, ErrNoMoreRows) {
68✔
1732
                        break
17✔
1733
                } else if err != nil {
35✔
1734
                        return nil, err
1✔
1735
                }
1✔
1736

1737
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
33✔
1738

33✔
1739
                for _, col := range table.cols {
124✔
1740
                        encSel := EncodeSelector("", table.name, col.colName)
91✔
1741
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
91✔
1742
                }
91✔
1743

1744
                for _, update := range stmt.updates {
68✔
1745
                        col, err := table.GetColumnByName(update.col)
35✔
1746
                        if err != nil {
35✔
1747
                                return nil, err
×
1748
                        }
×
1749

1750
                        sval, err := update.val.substitute(params)
35✔
1751
                        if err != nil {
35✔
1752
                                return nil, err
×
1753
                        }
×
1754

1755
                        rval, err := sval.reduce(tx, row, table.name)
35✔
1756
                        if err != nil {
35✔
1757
                                return nil, err
×
1758
                        }
×
1759

1760
                        err = rval.requiresType(col.colType, cols, nil, table.name)
35✔
1761
                        if err != nil {
35✔
1762
                                return nil, err
×
1763
                        }
×
1764

1765
                        valuesByColID[col.id] = rval
35✔
1766
                }
1767

1768
                for i, col := range table.cols {
124✔
1769
                        v := valuesByColID[col.id]
91✔
1770

91✔
1771
                        row.ValuesByPosition[i] = v
91✔
1772
                        row.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
91✔
1773
                }
91✔
1774

1775
                if err := checkConstraints(tx, table.checkConstraints, row, table.name); err != nil {
35✔
1776
                        return nil, err
2✔
1777
                }
2✔
1778

1779
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
31✔
1780
                if err != nil {
31✔
1781
                        return nil, err
×
1782
                }
×
1783

1784
                // primary index entry
1785
                mkey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
31✔
1786

31✔
1787
                // mkey must exist
31✔
1788
                _, err = tx.get(ctx, mkey)
31✔
1789
                if err != nil {
31✔
1790
                        return nil, err
×
1791
                }
×
1792

1793
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, true)
31✔
1794
                if err != nil {
31✔
1795
                        return nil, err
×
1796
                }
×
1797
        }
1798

1799
        return tx, nil
17✔
1800
}
1801

1802
type DeleteFromStmt struct {
1803
        tableRef *tableRef
1804
        where    ValueExp
1805
        indexOn  []string
1806
        orderBy  []*OrdExp
1807
        limit    ValueExp
1808
        offset   ValueExp
1809
}
1810

1811
func NewDeleteFromStmt(table string, where ValueExp, orderBy []*OrdExp, limit ValueExp) *DeleteFromStmt {
4✔
1812
        return &DeleteFromStmt{
4✔
1813
                tableRef: NewTableRef(table, ""),
4✔
1814
                where:    where,
4✔
1815
                orderBy:  orderBy,
4✔
1816
                limit:    limit,
4✔
1817
        }
4✔
1818
}
4✔
1819

1820
func (stmt *DeleteFromStmt) readOnly() bool {
1✔
1821
        return false
1✔
1822
}
1✔
1823

1824
func (stmt *DeleteFromStmt) requiredPrivileges() []SQLPrivilege {
1✔
1825
        return []SQLPrivilege{SQLPrivilegeDelete}
1✔
1826
}
1✔
1827

1828
func (stmt *DeleteFromStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1829
        selectStmt := &SelectStmt{
1✔
1830
                ds:      stmt.tableRef,
1✔
1831
                where:   stmt.where,
1✔
1832
                orderBy: stmt.orderBy,
1✔
1833
        }
1✔
1834
        return selectStmt.inferParameters(ctx, tx, params)
1✔
1835
}
1✔
1836

1837
func (stmt *DeleteFromStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
1838
        selectStmt := &SelectStmt{
15✔
1839
                ds:      stmt.tableRef,
15✔
1840
                where:   stmt.where,
15✔
1841
                indexOn: stmt.indexOn,
15✔
1842
                orderBy: stmt.orderBy,
15✔
1843
                limit:   stmt.limit,
15✔
1844
                offset:  stmt.offset,
15✔
1845
        }
15✔
1846

15✔
1847
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
15✔
1848
        if err != nil {
17✔
1849
                return nil, err
2✔
1850
        }
2✔
1851
        defer rowReader.Close()
13✔
1852

13✔
1853
        table := rowReader.ScanSpecs().Index.table
13✔
1854

13✔
1855
        for {
147✔
1856
                row, err := rowReader.Read(ctx)
134✔
1857
                if errors.Is(err, ErrNoMoreRows) {
146✔
1858
                        break
12✔
1859
                }
1860
                if err != nil {
123✔
1861
                        return nil, err
1✔
1862
                }
1✔
1863

1864
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
121✔
1865

121✔
1866
                for _, col := range table.cols {
406✔
1867
                        encSel := EncodeSelector("", table.name, col.colName)
285✔
1868
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
285✔
1869
                }
285✔
1870

1871
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
121✔
1872
                if err != nil {
121✔
1873
                        return nil, err
×
1874
                }
×
1875

1876
                err = tx.deleteIndexEntries(pkEncVals, valuesByColID, table)
121✔
1877
                if err != nil {
121✔
1878
                        return nil, err
×
1879
                }
×
1880

1881
                tx.updatedRows++
121✔
1882
        }
1883
        return tx, nil
12✔
1884
}
1885

1886
func (tx *SQLTx) deleteIndexEntries(pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table) error {
121✔
1887
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
121✔
1888
        if err != nil {
121✔
1889
                return err
×
1890
        }
×
1891

1892
        for _, index := range table.indexes {
291✔
1893
                if !index.IsPrimary() {
219✔
1894
                        continue
49✔
1895
                }
1896

1897
                encodedValues := make([][]byte, 3+len(index.cols))
121✔
1898
                encodedValues[0] = EncodeID(DatabaseID)
121✔
1899
                encodedValues[1] = EncodeID(table.id)
121✔
1900
                encodedValues[2] = EncodeID(index.id)
121✔
1901

121✔
1902
                for i, col := range index.cols {
242✔
1903
                        val, specified := valuesByColID[col.id]
121✔
1904
                        if !specified {
121✔
1905
                                val = &NullValue{t: col.colType}
×
1906
                        }
×
1907

1908
                        encVal, _, _ := EncodeValueAsKey(val, col.colType, col.MaxLen())
121✔
1909

121✔
1910
                        encodedValues[i+3] = encVal
121✔
1911
                }
1912

1913
                md := store.NewKVMetadata()
121✔
1914

121✔
1915
                md.AsDeleted(true)
121✔
1916

121✔
1917
                err := tx.set(MapKey(tx.sqlPrefix(), RowPrefix, encodedValues...), md, encodedRowValue)
121✔
1918
                if err != nil {
121✔
1919
                        return err
×
1920
                }
×
1921
        }
1922

1923
        return nil
121✔
1924
}
1925

1926
type ValueExp interface {
1927
        inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
1928
        requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
1929
        substitute(params map[string]interface{}) (ValueExp, error)
1930
        selectors() []Selector
1931
        reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error)
1932
        reduceSelectors(row *Row, implicitTable string) ValueExp
1933
        isConstant() bool
1934
        selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error
1935
        String() string
1936
}
1937

1938
type typedValueRange struct {
1939
        lRange *typedValueSemiRange
1940
        hRange *typedValueSemiRange
1941
}
1942

1943
type typedValueSemiRange struct {
1944
        val       TypedValue
1945
        inclusive bool
1946
}
1947

1948
func (r *typedValueRange) unitary() bool {
19✔
1949
        // TODO: this simplified implementation doesn't cover all unitary cases e.g. 3<=v<4
19✔
1950
        if r.lRange == nil || r.hRange == nil {
19✔
1951
                return false
×
1952
        }
×
1953

1954
        res, _ := r.lRange.val.Compare(r.hRange.val)
19✔
1955
        return res == 0 && r.lRange.inclusive && r.hRange.inclusive
19✔
1956
}
1957

1958
func (r *typedValueRange) refineWith(refiningRange *typedValueRange) error {
5✔
1959
        if r.lRange == nil {
6✔
1960
                r.lRange = refiningRange.lRange
1✔
1961
        } else if r.lRange != nil && refiningRange.lRange != nil {
6✔
1962
                maxRange, err := maxSemiRange(r.lRange, refiningRange.lRange)
1✔
1963
                if err != nil {
1✔
1964
                        return err
×
1965
                }
×
1966
                r.lRange = maxRange
1✔
1967
        }
1968

1969
        if r.hRange == nil {
8✔
1970
                r.hRange = refiningRange.hRange
3✔
1971
        } else if r.hRange != nil && refiningRange.hRange != nil {
7✔
1972
                minRange, err := minSemiRange(r.hRange, refiningRange.hRange)
2✔
1973
                if err != nil {
2✔
1974
                        return err
×
1975
                }
×
1976
                r.hRange = minRange
2✔
1977
        }
1978

1979
        return nil
5✔
1980
}
1981

1982
func (r *typedValueRange) extendWith(extendingRange *typedValueRange) error {
5✔
1983
        if r.lRange == nil || extendingRange.lRange == nil {
7✔
1984
                r.lRange = nil
2✔
1985
        } else {
5✔
1986
                minRange, err := minSemiRange(r.lRange, extendingRange.lRange)
3✔
1987
                if err != nil {
3✔
1988
                        return err
×
1989
                }
×
1990
                r.lRange = minRange
3✔
1991
        }
1992

1993
        if r.hRange == nil || extendingRange.hRange == nil {
8✔
1994
                r.hRange = nil
3✔
1995
        } else {
5✔
1996
                maxRange, err := maxSemiRange(r.hRange, extendingRange.hRange)
2✔
1997
                if err != nil {
2✔
1998
                        return err
×
1999
                }
×
2000
                r.hRange = maxRange
2✔
2001
        }
2002

2003
        return nil
5✔
2004
}
2005

2006
func maxSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
3✔
2007
        r, err := or1.val.Compare(or2.val)
3✔
2008
        if err != nil {
3✔
2009
                return nil, err
×
2010
        }
×
2011

2012
        maxVal := or1.val
3✔
2013
        if r < 0 {
5✔
2014
                maxVal = or2.val
2✔
2015
        }
2✔
2016

2017
        return &typedValueSemiRange{
3✔
2018
                val:       maxVal,
3✔
2019
                inclusive: or1.inclusive && or2.inclusive,
3✔
2020
        }, nil
3✔
2021
}
2022

2023
func minSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
5✔
2024
        r, err := or1.val.Compare(or2.val)
5✔
2025
        if err != nil {
5✔
2026
                return nil, err
×
2027
        }
×
2028

2029
        minVal := or1.val
5✔
2030
        if r > 0 {
9✔
2031
                minVal = or2.val
4✔
2032
        }
4✔
2033

2034
        return &typedValueSemiRange{
5✔
2035
                val:       minVal,
5✔
2036
                inclusive: or1.inclusive || or2.inclusive,
5✔
2037
        }, nil
5✔
2038
}
2039

2040
type TypedValue interface {
2041
        ValueExp
2042
        Type() SQLValueType
2043
        RawValue() interface{}
2044
        Compare(val TypedValue) (int, error)
2045
        IsNull() bool
2046
}
2047

2048
type Tuple []TypedValue
2049

2050
func (t Tuple) Compare(other Tuple) (int, int, error) {
204,158✔
2051
        if len(t) != len(other) {
204,158✔
2052
                return -1, -1, ErrNotComparableValues
×
2053
        }
×
2054

2055
        for i := range t {
431,257✔
2056
                res, err := t[i].Compare(other[i])
227,099✔
2057
                if err != nil || res != 0 {
420,947✔
2058
                        return res, i, err
193,848✔
2059
                }
193,848✔
2060
        }
2061
        return 0, -1, nil
10,310✔
2062
}
2063

2064
func NewNull(t SQLValueType) *NullValue {
391✔
2065
        return &NullValue{t: t}
391✔
2066
}
391✔
2067

2068
type NullValue struct {
2069
        t SQLValueType
2070
}
2071

2072
func (n *NullValue) Type() SQLValueType {
107✔
2073
        return n.t
107✔
2074
}
107✔
2075

2076
func (n *NullValue) RawValue() interface{} {
356✔
2077
        return nil
356✔
2078
}
356✔
2079

2080
func (n *NullValue) IsNull() bool {
381✔
2081
        return true
381✔
2082
}
381✔
2083

2084
func (n *NullValue) String() string {
4✔
2085
        return "NULL"
4✔
2086
}
4✔
2087

2088
func (n *NullValue) Compare(val TypedValue) (int, error) {
85✔
2089
        if n.t != AnyType && val.Type() != AnyType && n.t != val.Type() {
87✔
2090
                return 0, ErrNotComparableValues
2✔
2091
        }
2✔
2092

2093
        if val.RawValue() == nil {
124✔
2094
                return 0, nil
41✔
2095
        }
41✔
2096
        return -1, nil
42✔
2097
}
2098

2099
func (v *NullValue) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
8✔
2100
        return v.t, nil
8✔
2101
}
8✔
2102

2103
func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
10✔
2104
        if v.t == t {
15✔
2105
                return nil
5✔
2106
        }
5✔
2107

2108
        if v.t != AnyType {
6✔
2109
                return ErrInvalidTypes
1✔
2110
        }
1✔
2111

2112
        v.t = t
4✔
2113

4✔
2114
        return nil
4✔
2115
}
2116

2117
func (v *NullValue) selectors() []Selector {
13✔
2118
        return nil
13✔
2119
}
13✔
2120

2121
func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) {
394✔
2122
        return v, nil
394✔
2123
}
394✔
2124

2125
func (v *NullValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
352✔
2126
        return v, nil
352✔
2127
}
352✔
2128

2129
func (v *NullValue) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
2130
        return v
10✔
2131
}
10✔
2132

2133
func (v *NullValue) isConstant() bool {
12✔
2134
        return true
12✔
2135
}
12✔
2136

2137
func (v *NullValue) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2138
        return nil
1✔
2139
}
1✔
2140

2141
type Integer struct {
2142
        val int64
2143
}
2144

2145
func NewInteger(val int64) *Integer {
315✔
2146
        return &Integer{val: val}
315✔
2147
}
315✔
2148

2149
func (v *Integer) Type() SQLValueType {
302,814✔
2150
        return IntegerType
302,814✔
2151
}
302,814✔
2152

2153
func (v *Integer) IsNull() bool {
116,553✔
2154
        return false
116,553✔
2155
}
116,553✔
2156

2157
func (v *Integer) String() string {
54✔
2158
        return strconv.FormatInt(v.val, 10)
54✔
2159
}
54✔
2160

2161
func (v *Integer) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
98✔
2162
        return IntegerType, nil
98✔
2163
}
98✔
2164

2165
func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
63✔
2166
        if t != IntegerType && t != JSONType {
67✔
2167
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
4✔
2168
        }
4✔
2169
        return nil
59✔
2170
}
2171

2172
func (v *Integer) selectors() []Selector {
51✔
2173
        return nil
51✔
2174
}
51✔
2175

2176
func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) {
15,159✔
2177
        return v, nil
15,159✔
2178
}
15,159✔
2179

2180
func (v *Integer) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
16,908✔
2181
        return v, nil
16,908✔
2182
}
16,908✔
2183

2184
func (v *Integer) reduceSelectors(row *Row, implicitTable string) ValueExp {
9✔
2185
        return v
9✔
2186
}
9✔
2187

2188
func (v *Integer) isConstant() bool {
120✔
2189
        return true
120✔
2190
}
120✔
2191

2192
func (v *Integer) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2193
        return nil
1✔
2194
}
1✔
2195

2196
func (v *Integer) RawValue() interface{} {
177,190✔
2197
        return v.val
177,190✔
2198
}
177,190✔
2199

2200
func (v *Integer) Compare(val TypedValue) (int, error) {
93,042✔
2201
        if val.IsNull() {
93,085✔
2202
                return 1, nil
43✔
2203
        }
43✔
2204

2205
        if val.Type() == JSONType {
93,000✔
2206
                res, err := val.Compare(v)
1✔
2207
                return -res, err
1✔
2208
        }
1✔
2209

2210
        if val.Type() == Float64Type {
92,998✔
2211
                r, err := val.Compare(v)
×
2212
                return r * -1, err
×
2213
        }
×
2214

2215
        if val.Type() != IntegerType {
93,005✔
2216
                return 0, ErrNotComparableValues
7✔
2217
        }
7✔
2218

2219
        rval := val.RawValue().(int64)
92,991✔
2220

92,991✔
2221
        if v.val == rval {
111,639✔
2222
                return 0, nil
18,648✔
2223
        }
18,648✔
2224

2225
        if v.val > rval {
109,220✔
2226
                return 1, nil
34,877✔
2227
        }
34,877✔
2228

2229
        return -1, nil
39,466✔
2230
}
2231

2232
type Timestamp struct {
2233
        val time.Time
2234
}
2235

2236
func (v *Timestamp) Type() SQLValueType {
40,315✔
2237
        return TimestampType
40,315✔
2238
}
40,315✔
2239

2240
func (v *Timestamp) IsNull() bool {
32,894✔
2241
        return false
32,894✔
2242
}
32,894✔
2243

2244
func (v *Timestamp) String() string {
1✔
2245
        return v.val.Format("2006-01-02 15:04:05.999999")
1✔
2246
}
1✔
2247

2248
func (v *Timestamp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2249
        return TimestampType, nil
1✔
2250
}
1✔
2251

2252
func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
14✔
2253
        if t != TimestampType {
15✔
2254
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2255
        }
1✔
2256

2257
        return nil
13✔
2258
}
2259

2260
func (v *Timestamp) selectors() []Selector {
1✔
2261
        return nil
1✔
2262
}
1✔
2263

2264
func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) {
1,165✔
2265
        return v, nil
1,165✔
2266
}
1,165✔
2267

2268
func (v *Timestamp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2,025✔
2269
        return v, nil
2,025✔
2270
}
2,025✔
2271

2272
func (v *Timestamp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2273
        return v
1✔
2274
}
1✔
2275

2276
func (v *Timestamp) isConstant() bool {
1✔
2277
        return true
1✔
2278
}
1✔
2279

2280
func (v *Timestamp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2281
        return nil
1✔
2282
}
1✔
2283

2284
func (v *Timestamp) RawValue() interface{} {
57,410✔
2285
        return v.val
57,410✔
2286
}
57,410✔
2287

2288
func (v *Timestamp) Compare(val TypedValue) (int, error) {
29,744✔
2289
        if val.IsNull() {
29,746✔
2290
                return 1, nil
2✔
2291
        }
2✔
2292

2293
        if val.Type() != TimestampType {
29,743✔
2294
                return 0, ErrNotComparableValues
1✔
2295
        }
1✔
2296

2297
        rval := val.RawValue().(time.Time)
29,741✔
2298

29,741✔
2299
        if v.val.Before(rval) {
44,251✔
2300
                return -1, nil
14,510✔
2301
        }
14,510✔
2302

2303
        if v.val.After(rval) {
30,271✔
2304
                return 1, nil
15,040✔
2305
        }
15,040✔
2306

2307
        return 0, nil
191✔
2308
}
2309

2310
type Varchar struct {
2311
        val string
2312
}
2313

2314
func NewVarchar(val string) *Varchar {
2,137✔
2315
        return &Varchar{val: val}
2,137✔
2316
}
2,137✔
2317

2318
func (v *Varchar) Type() SQLValueType {
127,946✔
2319
        return VarcharType
127,946✔
2320
}
127,946✔
2321

2322
func (v *Varchar) IsNull() bool {
64,574✔
2323
        return false
64,574✔
2324
}
64,574✔
2325

2326
func (v *Varchar) String() string {
18✔
2327
        return fmt.Sprintf("'%s'", v.val)
18✔
2328
}
18✔
2329

2330
func (v *Varchar) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
61✔
2331
        return VarcharType, nil
61✔
2332
}
61✔
2333

2334
func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
142✔
2335
        if t != VarcharType && t != JSONType {
146✔
2336
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
4✔
2337
        }
4✔
2338
        return nil
138✔
2339
}
2340

2341
func (v *Varchar) selectors() []Selector {
39✔
2342
        return nil
39✔
2343
}
39✔
2344

2345
func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) {
4,950✔
2346
        return v, nil
4,950✔
2347
}
4,950✔
2348

2349
func (v *Varchar) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5,638✔
2350
        return v, nil
5,638✔
2351
}
5,638✔
2352

2353
func (v *Varchar) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2354
        return v
×
2355
}
×
2356

2357
func (v *Varchar) isConstant() bool {
39✔
2358
        return true
39✔
2359
}
39✔
2360

2361
func (v *Varchar) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2362
        return nil
1✔
2363
}
1✔
2364

2365
func (v *Varchar) RawValue() interface{} {
90,162✔
2366
        return v.val
90,162✔
2367
}
90,162✔
2368

2369
func (v *Varchar) Compare(val TypedValue) (int, error) {
58,122✔
2370
        if val.IsNull() {
58,178✔
2371
                return 1, nil
56✔
2372
        }
56✔
2373

2374
        if val.Type() == JSONType {
59,067✔
2375
                res, err := val.Compare(v)
1,001✔
2376
                return -res, err
1,001✔
2377
        }
1,001✔
2378

2379
        if val.Type() != VarcharType {
57,066✔
2380
                return 0, ErrNotComparableValues
1✔
2381
        }
1✔
2382

2383
        rval := val.RawValue().(string)
57,064✔
2384

57,064✔
2385
        return bytes.Compare([]byte(v.val), []byte(rval)), nil
57,064✔
2386
}
2387

2388
type UUID struct {
2389
        val uuid.UUID
2390
}
2391

2392
func NewUUID(val uuid.UUID) *UUID {
1✔
2393
        return &UUID{val: val}
1✔
2394
}
1✔
2395

2396
func (v *UUID) Type() SQLValueType {
10✔
2397
        return UUIDType
10✔
2398
}
10✔
2399

2400
func (v *UUID) IsNull() bool {
26✔
2401
        return false
26✔
2402
}
26✔
2403

2404
func (v *UUID) String() string {
1✔
2405
        return v.val.String()
1✔
2406
}
1✔
2407

2408
func (v *UUID) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2409
        return UUIDType, nil
1✔
2410
}
1✔
2411

2412
func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
2413
        if t != UUIDType {
6✔
2414
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
2✔
2415
        }
2✔
2416

2417
        return nil
2✔
2418
}
2419

2420
func (v *UUID) selectors() []Selector {
1✔
2421
        return nil
1✔
2422
}
1✔
2423

2424
func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) {
6✔
2425
        return v, nil
6✔
2426
}
6✔
2427

2428
func (v *UUID) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5✔
2429
        return v, nil
5✔
2430
}
5✔
2431

2432
func (v *UUID) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2433
        return v
1✔
2434
}
1✔
2435

2436
func (v *UUID) isConstant() bool {
1✔
2437
        return true
1✔
2438
}
1✔
2439

2440
func (v *UUID) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2441
        return nil
1✔
2442
}
1✔
2443

2444
func (v *UUID) RawValue() interface{} {
41✔
2445
        return v.val
41✔
2446
}
41✔
2447

2448
func (v *UUID) Compare(val TypedValue) (int, error) {
5✔
2449
        if val.IsNull() {
7✔
2450
                return 1, nil
2✔
2451
        }
2✔
2452

2453
        if val.Type() != UUIDType {
4✔
2454
                return 0, ErrNotComparableValues
1✔
2455
        }
1✔
2456

2457
        rval := val.RawValue().(uuid.UUID)
2✔
2458

2✔
2459
        return bytes.Compare(v.val[:], rval[:]), nil
2✔
2460
}
2461

2462
type Bool struct {
2463
        val bool
2464
}
2465

2466
func NewBool(val bool) *Bool {
208✔
2467
        return &Bool{val: val}
208✔
2468
}
208✔
2469

2470
func (v *Bool) Type() SQLValueType {
1,970✔
2471
        return BooleanType
1,970✔
2472
}
1,970✔
2473

2474
func (v *Bool) IsNull() bool {
1,479✔
2475
        return false
1,479✔
2476
}
1,479✔
2477

2478
func (v *Bool) String() string {
41✔
2479
        return strconv.FormatBool(v.val)
41✔
2480
}
41✔
2481

2482
func (v *Bool) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
32✔
2483
        return BooleanType, nil
32✔
2484
}
32✔
2485

2486
func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
57✔
2487
        if t != BooleanType && t != JSONType {
62✔
2488
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
5✔
2489
        }
5✔
2490
        return nil
52✔
2491
}
2492

2493
func (v *Bool) selectors() []Selector {
5✔
2494
        return nil
5✔
2495
}
5✔
2496

2497
func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) {
639✔
2498
        return v, nil
639✔
2499
}
639✔
2500

2501
func (v *Bool) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2502
        return v, nil
726✔
2503
}
726✔
2504

2505
func (v *Bool) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2506
        return v
×
2507
}
×
2508

2509
func (v *Bool) isConstant() bool {
3✔
2510
        return true
3✔
2511
}
3✔
2512

2513
func (v *Bool) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
2514
        return nil
8✔
2515
}
8✔
2516

2517
func (v *Bool) RawValue() interface{} {
1,716✔
2518
        return v.val
1,716✔
2519
}
1,716✔
2520

2521
func (v *Bool) Compare(val TypedValue) (int, error) {
575✔
2522
        if val.IsNull() {
605✔
2523
                return 1, nil
30✔
2524
        }
30✔
2525

2526
        if val.Type() == JSONType {
546✔
2527
                res, err := val.Compare(v)
1✔
2528
                return -res, err
1✔
2529
        }
1✔
2530

2531
        if val.Type() != BooleanType {
544✔
2532
                return 0, ErrNotComparableValues
×
2533
        }
×
2534

2535
        rval := val.RawValue().(bool)
544✔
2536

544✔
2537
        if v.val == rval {
888✔
2538
                return 0, nil
344✔
2539
        }
344✔
2540

2541
        if v.val {
206✔
2542
                return 1, nil
6✔
2543
        }
6✔
2544

2545
        return -1, nil
194✔
2546
}
2547

2548
type Blob struct {
2549
        val []byte
2550
}
2551

2552
func NewBlob(val []byte) *Blob {
286✔
2553
        return &Blob{val: val}
286✔
2554
}
286✔
2555

2556
func (v *Blob) Type() SQLValueType {
53✔
2557
        return BLOBType
53✔
2558
}
53✔
2559

2560
func (v *Blob) IsNull() bool {
2,312✔
2561
        return false
2,312✔
2562
}
2,312✔
2563

2564
func (v *Blob) String() string {
2✔
2565
        return hex.EncodeToString(v.val)
2✔
2566
}
2✔
2567

2568
func (v *Blob) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2569
        return BLOBType, nil
1✔
2570
}
1✔
2571

2572
func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
2573
        if t != BLOBType {
3✔
2574
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BLOBType, t)
1✔
2575
        }
1✔
2576

2577
        return nil
1✔
2578
}
2579

2580
func (v *Blob) selectors() []Selector {
1✔
2581
        return nil
1✔
2582
}
1✔
2583

2584
func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) {
714✔
2585
        return v, nil
714✔
2586
}
714✔
2587

2588
func (v *Blob) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2589
        return v, nil
726✔
2590
}
726✔
2591

2592
func (v *Blob) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2593
        return v
×
2594
}
×
2595

2596
func (v *Blob) isConstant() bool {
7✔
2597
        return true
7✔
2598
}
7✔
2599

2600
func (v *Blob) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2601
        return nil
×
2602
}
×
2603

2604
func (v *Blob) RawValue() interface{} {
2,570✔
2605
        return v.val
2,570✔
2606
}
2,570✔
2607

2608
func (v *Blob) Compare(val TypedValue) (int, error) {
25✔
2609
        if val.IsNull() {
25✔
2610
                return 1, nil
×
2611
        }
×
2612

2613
        if val.Type() != BLOBType {
25✔
2614
                return 0, ErrNotComparableValues
×
2615
        }
×
2616

2617
        rval := val.RawValue().([]byte)
25✔
2618

25✔
2619
        return bytes.Compare(v.val, rval), nil
25✔
2620
}
2621

2622
type Float64 struct {
2623
        val float64
2624
}
2625

2626
func NewFloat64(val float64) *Float64 {
1,249✔
2627
        return &Float64{val: val}
1,249✔
2628
}
1,249✔
2629

2630
func (v *Float64) Type() SQLValueType {
207,438✔
2631
        return Float64Type
207,438✔
2632
}
207,438✔
2633

2634
func (v *Float64) IsNull() bool {
5,598✔
2635
        return false
5,598✔
2636
}
5,598✔
2637

2638
func (v *Float64) String() string {
3✔
2639
        return strconv.FormatFloat(float64(v.val), 'f', -1, 64)
3✔
2640
}
3✔
2641

2642
func (v *Float64) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
15✔
2643
        return Float64Type, nil
15✔
2644
}
15✔
2645

2646
func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2647
        if t != Float64Type && t != JSONType {
21✔
2648
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, Float64Type, t)
1✔
2649
        }
1✔
2650
        return nil
19✔
2651
}
2652

2653
func (v *Float64) selectors() []Selector {
2✔
2654
        return nil
2✔
2655
}
2✔
2656

2657
func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) {
1,817✔
2658
        return v, nil
1,817✔
2659
}
1,817✔
2660

2661
func (v *Float64) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
3,428✔
2662
        return v, nil
3,428✔
2663
}
3,428✔
2664

2665
func (v *Float64) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2666
        return v
1✔
2667
}
1✔
2668

2669
func (v *Float64) isConstant() bool {
5✔
2670
        return true
5✔
2671
}
5✔
2672

2673
func (v *Float64) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2674
        return nil
1✔
2675
}
1✔
2676

2677
func (v *Float64) RawValue() interface{} {
365,967✔
2678
        return v.val
365,967✔
2679
}
365,967✔
2680

2681
func (v *Float64) Compare(val TypedValue) (int, error) {
61,874✔
2682
        if val.Type() == JSONType {
61,875✔
2683
                res, err := val.Compare(v)
1✔
2684
                return -res, err
1✔
2685
        }
1✔
2686

2687
        convVal, err := mayApplyImplicitConversion(val.RawValue(), Float64Type)
61,873✔
2688
        if err != nil {
61,874✔
2689
                return 0, err
1✔
2690
        }
1✔
2691

2692
        if convVal == nil {
61,875✔
2693
                return 1, nil
3✔
2694
        }
3✔
2695

2696
        rval, ok := convVal.(float64)
61,869✔
2697
        if !ok {
61,869✔
2698
                return 0, ErrNotComparableValues
×
2699
        }
×
2700

2701
        if v.val == rval {
61,993✔
2702
                return 0, nil
124✔
2703
        }
124✔
2704

2705
        if v.val > rval {
90,617✔
2706
                return 1, nil
28,872✔
2707
        }
28,872✔
2708

2709
        return -1, nil
32,873✔
2710
}
2711

2712
type FnCall struct {
2713
        fn     string
2714
        params []ValueExp
2715
}
2716

2717
func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
24✔
2718
        fn, err := v.resolveFunc()
24✔
2719
        if err != nil {
25✔
2720
                return AnyType, nil
1✔
2721
        }
1✔
2722
        return fn.InferType(cols, params, implicitTable)
23✔
2723
}
2724

2725
func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2726
        fn, err := v.resolveFunc()
20✔
2727
        if err != nil {
21✔
2728
                return err
1✔
2729
        }
1✔
2730
        return fn.RequiresType(t, cols, params, implicitTable)
19✔
2731
}
2732

2733
func (v *FnCall) selectors() []Selector {
33✔
2734
        selectors := make([]Selector, 0)
33✔
2735
        for _, param := range v.params {
89✔
2736
                selectors = append(selectors, param.selectors()...)
56✔
2737
        }
56✔
2738
        return selectors
33✔
2739
}
2740

2741
func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) {
433✔
2742
        ps := make([]ValueExp, len(v.params))
433✔
2743
        for i, p := range v.params {
791✔
2744
                ps[i], err = p.substitute(params)
358✔
2745
                if err != nil {
358✔
2746
                        return nil, err
×
2747
                }
×
2748
        }
2749

2750
        return &FnCall{
433✔
2751
                fn:     v.fn,
433✔
2752
                params: ps,
433✔
2753
        }, nil
433✔
2754
}
2755

2756
func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
433✔
2757
        fn, err := v.resolveFunc()
433✔
2758
        if err != nil {
434✔
2759
                return nil, err
1✔
2760
        }
1✔
2761

2762
        fnInputs, err := v.reduceParams(tx, row, implicitTable)
432✔
2763
        if err != nil {
432✔
2764
                return nil, err
×
2765
        }
×
2766
        return fn.Apply(tx, fnInputs)
432✔
2767
}
2768

2769
func (v *FnCall) reduceParams(tx *SQLTx, row *Row, implicitTable string) ([]TypedValue, error) {
432✔
2770
        var values []TypedValue
432✔
2771
        if len(v.params) > 0 {
763✔
2772
                values = make([]TypedValue, len(v.params))
331✔
2773
                for i, p := range v.params {
689✔
2774
                        v, err := p.reduce(tx, row, implicitTable)
358✔
2775
                        if err != nil {
358✔
2776
                                return nil, err
×
2777
                        }
×
2778
                        values[i] = v
358✔
2779
                }
2780
        }
2781
        return values, nil
432✔
2782
}
2783

2784
func (v *FnCall) resolveFunc() (Function, error) {
477✔
2785
        fn, exists := builtinFunctions[strings.ToUpper(v.fn)]
477✔
2786
        if !exists {
480✔
2787
                return nil, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn)
3✔
2788
        }
3✔
2789
        return fn, nil
474✔
2790
}
2791

2792
func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2793
        return v
×
2794
}
×
2795

2796
func (v *FnCall) isConstant() bool {
13✔
2797
        return false
13✔
2798
}
13✔
2799

2800
func (v *FnCall) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2801
        return nil
×
2802
}
×
2803

2804
func (v *FnCall) String() string {
1✔
2805
        params := make([]string, len(v.params))
1✔
2806
        for i, p := range v.params {
4✔
2807
                params[i] = p.String()
3✔
2808
        }
3✔
2809
        return v.fn + "(" + strings.Join(params, ",") + ")"
1✔
2810
}
2811

2812
type Cast struct {
2813
        val ValueExp
2814
        t   SQLValueType
2815
}
2816

2817
func (c *Cast) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2818
        _, err := c.val.inferType(cols, params, implicitTable)
23✔
2819
        if err != nil {
24✔
2820
                return AnyType, err
1✔
2821
        }
1✔
2822

2823
        // val type may be restricted by compatible conversions, but multiple types may be compatible...
2824

2825
        return c.t, nil
22✔
2826
}
2827

2828
func (c *Cast) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
2829
        if c.t != t {
×
2830
                return fmt.Errorf("%w: can not use value cast to %s as %s", ErrInvalidTypes, c.t, t)
×
2831
        }
×
2832

2833
        return nil
×
2834
}
2835

2836
func (c *Cast) substitute(params map[string]interface{}) (ValueExp, error) {
281✔
2837
        val, err := c.val.substitute(params)
281✔
2838
        if err != nil {
281✔
2839
                return nil, err
×
2840
        }
×
2841
        c.val = val
281✔
2842
        return c, nil
281✔
2843
}
2844

2845
func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
269✔
2846
        val, err := c.val.reduce(tx, row, implicitTable)
269✔
2847
        if err != nil {
269✔
2848
                return nil, err
×
2849
        }
×
2850

2851
        conv, err := getConverter(val.Type(), c.t)
269✔
2852
        if conv == nil {
272✔
2853
                return nil, err
3✔
2854
        }
3✔
2855

2856
        return conv(val)
266✔
2857
}
2858

2859
func (v *Cast) selectors() []Selector {
6✔
2860
        return v.val.selectors()
6✔
2861
}
6✔
2862

2863
func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2864
        return &Cast{
×
2865
                val: c.val.reduceSelectors(row, implicitTable),
×
2866
                t:   c.t,
×
2867
        }
×
2868
}
×
2869

2870
func (c *Cast) isConstant() bool {
7✔
2871
        return c.val.isConstant()
7✔
2872
}
7✔
2873

2874
func (c *Cast) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2875
        return nil
×
2876
}
×
2877

2878
func (c *Cast) String() string {
1✔
2879
        return fmt.Sprintf("CAST (%s AS %s)", c.val.String(), c.t)
1✔
2880
}
1✔
2881

2882
type Param struct {
2883
        id  string
2884
        pos int
2885
}
2886

2887
func (v *Param) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
58✔
2888
        t, ok := params[v.id]
58✔
2889
        if !ok {
114✔
2890
                params[v.id] = AnyType
56✔
2891
                return AnyType, nil
56✔
2892
        }
56✔
2893

2894
        return t, nil
2✔
2895
}
2896

2897
func (v *Param) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
76✔
2898
        currT, ok := params[v.id]
76✔
2899
        if ok && currT != t && currT != AnyType {
80✔
2900
                return ErrInferredMultipleTypes
4✔
2901
        }
4✔
2902

2903
        params[v.id] = t
72✔
2904

72✔
2905
        return nil
72✔
2906
}
2907

2908
func (p *Param) substitute(params map[string]interface{}) (ValueExp, error) {
6,399✔
2909
        val, ok := params[p.id]
6,399✔
2910
        if !ok {
6,461✔
2911
                return nil, fmt.Errorf("%w(%s)", ErrMissingParameter, p.id)
62✔
2912
        }
62✔
2913

2914
        if val == nil {
6,386✔
2915
                return &NullValue{t: AnyType}, nil
49✔
2916
        }
49✔
2917

2918
        switch v := val.(type) {
6,288✔
2919
        case bool:
96✔
2920
                {
192✔
2921
                        return &Bool{val: v}, nil
96✔
2922
                }
96✔
2923
        case string:
1,752✔
2924
                {
3,504✔
2925
                        return &Varchar{val: v}, nil
1,752✔
2926
                }
1,752✔
2927
        case int:
1,678✔
2928
                {
3,356✔
2929
                        return &Integer{val: int64(v)}, nil
1,678✔
2930
                }
1,678✔
2931
        case uint:
×
2932
                {
×
2933
                        return &Integer{val: int64(v)}, nil
×
2934
                }
×
2935
        case uint64:
34✔
2936
                {
68✔
2937
                        return &Integer{val: int64(v)}, nil
34✔
2938
                }
34✔
2939
        case int64:
227✔
2940
                {
454✔
2941
                        return &Integer{val: v}, nil
227✔
2942
                }
227✔
2943
        case []byte:
14✔
2944
                {
28✔
2945
                        return &Blob{val: v}, nil
14✔
2946
                }
14✔
2947
        case time.Time:
861✔
2948
                {
1,722✔
2949
                        return &Timestamp{val: v.Truncate(time.Microsecond).UTC()}, nil
861✔
2950
                }
861✔
2951
        case float64:
1,625✔
2952
                {
3,250✔
2953
                        return &Float64{val: v}, nil
1,625✔
2954
                }
1,625✔
2955
        }
2956
        return nil, ErrUnsupportedParameter
1✔
2957
}
2958

2959
func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
×
2960
        return nil, ErrUnexpected
×
2961
}
×
2962

2963
func (p *Param) selectors() []Selector {
4✔
2964
        return nil
4✔
2965
}
4✔
2966

2967
func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2968
        return p
×
2969
}
×
2970

2971
func (p *Param) isConstant() bool {
130✔
2972
        return true
130✔
2973
}
130✔
2974

2975
func (v *Param) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
6✔
2976
        return nil
6✔
2977
}
6✔
2978

2979
func (v *Param) String() string {
2✔
2980
        return "@" + v.id
2✔
2981
}
2✔
2982

2983
// whenThenClause is one "WHEN <when> THEN <then>" arm of a CASE expression.
type whenThenClause struct {
	when, then ValueExp
}

// CaseWhenExp is a CASE expression.
//
// exp is the search operand of a simple CASE ("CASE exp WHEN v THEN ...").
// It is nil for a searched CASE ("CASE WHEN cond THEN ..."), in which case
// reduce matches each WHEN operand against boolean TRUE. elseExp is nil when
// there is no ELSE branch; reduce then yields NULL when no arm matches.
type CaseWhenExp struct {
	exp      ValueExp
	whenThen []whenThenClause
	elseExp  ValueExp
}
2992

2993
func (ce *CaseWhenExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
8✔
2994
        checkType := func(e ValueExp, expectedType SQLValueType) (string, error) {
18✔
2995
                t, err := e.inferType(cols, params, implicitTable)
10✔
2996
                if err != nil {
10✔
2997
                        return "", err
×
2998
                }
×
2999

3000
                if expectedType == AnyType {
15✔
3001
                        return t, nil
5✔
3002
                }
5✔
3003

3004
                if t != expectedType {
6✔
3005
                        if (t == Float64Type && expectedType == IntegerType) ||
1✔
3006
                                (t == IntegerType && expectedType == Float64Type) {
1✔
3007
                                return Float64Type, nil
×
3008
                        }
×
3009
                        return "", fmt.Errorf("%w: CASE types %s and %s cannot be matched", ErrInferredMultipleTypes, expectedType, t)
1✔
3010
                }
3011
                return t, nil
4✔
3012
        }
3013

3014
        searchType := BooleanType
8✔
3015
        inferredResType := AnyType
8✔
3016
        if ce.exp != nil {
11✔
3017
                t, err := ce.exp.inferType(cols, params, implicitTable)
3✔
3018
                if err != nil {
3✔
3019
                        return "", err
×
3020
                }
×
3021
                searchType = t
3✔
3022
        }
3023

3024
        for _, e := range ce.whenThen {
16✔
3025
                whenType, err := e.when.inferType(cols, params, implicitTable)
8✔
3026
                if err != nil {
8✔
3027
                        return "", err
×
3028
                }
×
3029

3030
                if whenType != searchType {
11✔
3031
                        return "", fmt.Errorf("%w: argument of CASE/WHEN must be of type %s, not type %s", ErrInvalidTypes, searchType, whenType)
3✔
3032
                }
3✔
3033

3034
                t, err := checkType(e.then, inferredResType)
5✔
3035
                if err != nil {
5✔
3036
                        return "", err
×
3037
                }
×
3038
                inferredResType = t
5✔
3039
        }
3040

3041
        if ce.elseExp != nil {
10✔
3042
                return checkType(ce.elseExp, inferredResType)
5✔
3043
        }
5✔
3044
        return inferredResType, nil
×
3045
}
3046

3047
func (ce *CaseWhenExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
3048
        inferredType, err := ce.inferType(cols, params, implicitTable)
2✔
3049
        if err != nil {
3✔
3050
                return err
1✔
3051
        }
1✔
3052

3053
        if inferredType != t {
1✔
3054
                return fmt.Errorf("%w: expected type %s but %s found instead", ErrInvalidTypes, t, inferredType)
×
3055
        }
×
3056
        return nil
1✔
3057
}
3058

3059
func (ce *CaseWhenExp) substitute(params map[string]interface{}) (ValueExp, error) {
504✔
3060
        var exp ValueExp
504✔
3061
        if ce.exp != nil {
605✔
3062
                e, err := ce.exp.substitute(params)
101✔
3063
                if err != nil {
101✔
3064
                        return nil, err
×
3065
                }
×
3066
                exp = e
101✔
3067
        }
3068

3069
        whenThen := make([]whenThenClause, len(ce.whenThen))
504✔
3070
        for i, wt := range ce.whenThen {
1,208✔
3071
                whenValue, err := wt.when.substitute(params)
704✔
3072
                if err != nil {
704✔
3073
                        return nil, err
×
3074
                }
×
3075
                whenThen[i].when = whenValue
704✔
3076

704✔
3077
                thenValue, err := wt.then.substitute(params)
704✔
3078
                if err != nil {
704✔
3079
                        return nil, err
×
3080
                }
×
3081
                whenThen[i].then = thenValue
704✔
3082
        }
3083

3084
        if ce.elseExp == nil {
506✔
3085
                return &CaseWhenExp{
2✔
3086
                        exp:      exp,
2✔
3087
                        whenThen: whenThen,
2✔
3088
                }, nil
2✔
3089
        }
2✔
3090

3091
        elseValue, err := ce.elseExp.substitute(params)
502✔
3092
        if err != nil {
502✔
3093
                return nil, err
×
3094
        }
×
3095
        return &CaseWhenExp{
502✔
3096
                exp:      exp,
502✔
3097
                whenThen: whenThen,
502✔
3098
                elseExp:  elseValue,
502✔
3099
        }, nil
502✔
3100
}
3101

3102
func (ce *CaseWhenExp) selectors() []Selector {
7✔
3103
        selectors := make([]Selector, 0)
7✔
3104
        if ce.exp != nil {
8✔
3105
                selectors = append(selectors, ce.exp.selectors()...)
1✔
3106
        }
1✔
3107

3108
        for _, wh := range ce.whenThen {
16✔
3109
                selectors = append(selectors, wh.when.selectors()...)
9✔
3110
                selectors = append(selectors, wh.then.selectors()...)
9✔
3111
        }
9✔
3112

3113
        if ce.elseExp == nil {
9✔
3114
                return selectors
2✔
3115
        }
2✔
3116
        return append(selectors, ce.elseExp.selectors()...)
5✔
3117
}
3118

3119
func (ce *CaseWhenExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
304✔
3120
        var searchValue TypedValue
304✔
3121
        if ce.exp != nil {
405✔
3122
                v, err := ce.exp.reduce(tx, row, implicitTable)
101✔
3123
                if err != nil {
101✔
3124
                        return nil, err
×
3125
                }
×
3126
                searchValue = v
101✔
3127
        } else {
203✔
3128
                searchValue = &Bool{val: true}
203✔
3129
        }
203✔
3130

3131
        for _, wt := range ce.whenThen {
734✔
3132
                v, err := wt.when.reduce(tx, row, implicitTable)
430✔
3133
                if err != nil {
430✔
3134
                        return nil, err
×
3135
                }
×
3136

3137
                if v.Type() != searchValue.Type() {
431✔
3138
                        return nil, fmt.Errorf("%w: argument of CASE/WHEN must be type %s, not type %s", ErrInvalidTypes, v.Type(), searchValue.Type())
1✔
3139
                }
1✔
3140

3141
                res, err := v.Compare(searchValue)
429✔
3142
                if err != nil {
429✔
3143
                        return nil, err
×
3144
                }
×
3145
                if res == 0 {
629✔
3146
                        return wt.then.reduce(tx, row, implicitTable)
200✔
3147
                }
200✔
3148
        }
3149

3150
        if ce.elseExp == nil {
104✔
3151
                return NewNull(AnyType), nil
1✔
3152
        }
1✔
3153
        return ce.elseExp.reduce(tx, row, implicitTable)
102✔
3154
}
3155

3156
func (ce *CaseWhenExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3157
        whenThen := make([]whenThenClause, len(ce.whenThen))
1✔
3158
        for i, wt := range ce.whenThen {
2✔
3159
                whenValue := wt.when.reduceSelectors(row, implicitTable)
1✔
3160
                whenThen[i].when = whenValue
1✔
3161

1✔
3162
                thenValue := wt.then.reduceSelectors(row, implicitTable)
1✔
3163
                whenThen[i].then = thenValue
1✔
3164
        }
1✔
3165

3166
        if ce.elseExp == nil {
1✔
3167
                return &CaseWhenExp{
×
3168
                        whenThen: whenThen,
×
3169
                }
×
3170
        }
×
3171

3172
        return &CaseWhenExp{
1✔
3173
                whenThen: whenThen,
1✔
3174
                elseExp:  ce.elseExp.reduceSelectors(row, implicitTable),
1✔
3175
        }
1✔
3176
}
3177

3178
func (ce *CaseWhenExp) isConstant() bool {
1✔
3179
        return false
1✔
3180
}
1✔
3181

3182
func (ce *CaseWhenExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
3183
        return nil
1✔
3184
}
1✔
3185

3186
func (ce *CaseWhenExp) String() string {
3✔
3187
        var sb strings.Builder
3✔
3188
        for _, wh := range ce.whenThen {
7✔
3189
                sb.WriteString(fmt.Sprintf("WHEN %s THEN %s ", wh.when.String(), wh.then.String()))
4✔
3190
        }
4✔
3191

3192
        if ce.elseExp != nil {
5✔
3193
                sb.WriteString("ELSE " + ce.elseExp.String() + " ")
2✔
3194
        }
2✔
3195
        return "CASE " + sb.String() + "END"
3✔
3196
}
3197

3198
// Comparison enumerates comparison operators.
type Comparison int

const (
	// EqualTo is "=".
	EqualTo Comparison = iota
	// LowerThan is "<".
	LowerThan
	// LowerOrEqualTo is "<=".
	LowerOrEqualTo
	// GreaterThan is ">".
	GreaterThan
	// GreaterOrEqualTo is ">=".
	GreaterOrEqualTo
)

// DataSource is a statement that can be used as the row source of a query
// (SelectStmt is one implementation).
type DataSource interface {
	SQLStmt
	// Resolve builds a RowReader over the source's rows, using params for
	// parameter values. Implementations may use scanSpecs to drive the scan;
	// SelectStmt ignores it and computes its own.
	Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error)
	// Alias returns the name the source is known by in the enclosing query
	// (presumably used to qualify its columns — confirm).
	Alias() string
}

// TargetEntry is one item of a SELECT list: the projected expression and its
// output name (presumably empty when no AS clause was given — confirm).
type TargetEntry struct {
	Exp ValueExp
	As  string
}
3218

3219
type SelectStmt struct {
3220
        distinct  bool
3221
        targets   []TargetEntry
3222
        selectors []Selector
3223
        ds        DataSource
3224
        indexOn   []string
3225
        joins     []*JoinSpec
3226
        where     ValueExp
3227
        groupBy   []*ColSelector
3228
        having    ValueExp
3229
        orderBy   []*OrdExp
3230
        limit     ValueExp
3231
        offset    ValueExp
3232
        as        string
3233
}
3234

3235
func NewSelectStmt(
3236
        targets []TargetEntry,
3237
        ds DataSource,
3238
        where ValueExp,
3239
        orderBy []*OrdExp,
3240
        limit ValueExp,
3241
        offset ValueExp,
3242
) *SelectStmt {
71✔
3243
        return &SelectStmt{
71✔
3244
                targets: targets,
71✔
3245
                ds:      ds,
71✔
3246
                where:   where,
71✔
3247
                orderBy: orderBy,
71✔
3248
                limit:   limit,
71✔
3249
                offset:  offset,
71✔
3250
        }
71✔
3251
}
71✔
3252

3253
func (stmt *SelectStmt) readOnly() bool {
94✔
3254
        return true
94✔
3255
}
94✔
3256

3257
func (stmt *SelectStmt) requiredPrivileges() []SQLPrivilege {
96✔
3258
        return []SQLPrivilege{SQLPrivilegeSelect}
96✔
3259
}
96✔
3260

3261
func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
53✔
3262
        _, err := stmt.execAt(ctx, tx, nil)
53✔
3263
        if err != nil {
53✔
3264
                return err
×
3265
        }
×
3266

3267
        // TODO: (jeroiraz) may be optimized so to resolve the query statement just once
3268
        rowReader, err := stmt.Resolve(ctx, tx, nil, nil)
53✔
3269
        if err != nil {
54✔
3270
                return err
1✔
3271
        }
1✔
3272
        defer rowReader.Close()
52✔
3273

52✔
3274
        return rowReader.InferParameters(ctx, params)
52✔
3275
}
3276

3277
func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
621✔
3278
        if stmt.groupBy == nil && stmt.having != nil {
622✔
3279
                return nil, ErrHavingClauseRequiresGroupClause
1✔
3280
        }
1✔
3281

3282
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
705✔
3283
                for _, sel := range stmt.targetSelectors() {
227✔
3284
                        _, isAgg := sel.(*AggColSelector)
142✔
3285
                        if !isAgg && !stmt.groupByContains(sel) {
144✔
3286
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3287
                        }
2✔
3288
                }
3289
        }
3290

3291
        if len(stmt.orderBy) > 0 {
775✔
3292
                for _, col := range stmt.orderBy {
352✔
3293
                        for _, sel := range col.exp.selectors() {
380✔
3294
                                _, isAgg := sel.(*AggColSelector)
185✔
3295
                                if (isAgg && !stmt.selectorAppearsInTargets(sel)) || (!isAgg && len(stmt.groupBy) > 0 && !stmt.groupByContains(sel)) {
187✔
3296
                                        return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3297
                                }
2✔
3298
                        }
3299
                }
3300
        }
3301
        return tx, nil
616✔
3302
}
3303

3304
func (stmt *SelectStmt) targetSelectors() []Selector {
2,524✔
3305
        if stmt.selectors == nil {
3,476✔
3306
                stmt.selectors = stmt.extractSelectors()
952✔
3307
        }
952✔
3308
        return stmt.selectors
2,524✔
3309
}
3310

3311
func (stmt *SelectStmt) selectorAppearsInTargets(s Selector) bool {
4✔
3312
        encSel := EncodeSelector(s.resolve(stmt.Alias()))
4✔
3313

4✔
3314
        for _, sel := range stmt.targetSelectors() {
12✔
3315
                if EncodeSelector(sel.resolve(stmt.Alias())) == encSel {
11✔
3316
                        return true
3✔
3317
                }
3✔
3318
        }
3319
        return false
1✔
3320
}
3321

3322
func (stmt *SelectStmt) groupByContains(sel Selector) bool {
57✔
3323
        encSel := EncodeSelector(sel.resolve(stmt.Alias()))
57✔
3324

57✔
3325
        for _, colSel := range stmt.groupBy {
137✔
3326
                if EncodeSelector(colSel.resolve(stmt.Alias())) == encSel {
134✔
3327
                        return true
54✔
3328
                }
54✔
3329
        }
3330
        return false
3✔
3331
}
3332

3333
func (stmt *SelectStmt) extractGroupByCols() []*AggColSelector {
80✔
3334
        cols := make([]*AggColSelector, 0, len(stmt.targets))
80✔
3335

80✔
3336
        for _, t := range stmt.targets {
214✔
3337
                selectors := t.Exp.selectors()
134✔
3338
                for _, sel := range selectors {
268✔
3339
                        aggSel, isAgg := sel.(*AggColSelector)
134✔
3340
                        if isAgg {
244✔
3341
                                cols = append(cols, aggSel)
110✔
3342
                        }
110✔
3343
                }
3344
        }
3345
        return cols
80✔
3346
}
3347

3348
func (stmt *SelectStmt) extractSelectors() []Selector {
952✔
3349
        selectors := make([]Selector, 0, len(stmt.targets))
952✔
3350
        for _, t := range stmt.targets {
1,835✔
3351
                selectors = append(selectors, t.Exp.selectors()...)
883✔
3352
        }
883✔
3353
        return selectors
952✔
3354
}
3355

3356
func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
959✔
3357
        scanSpecs, err := stmt.genScanSpecs(tx, params)
959✔
3358
        if err != nil {
975✔
3359
                return nil, err
16✔
3360
        }
16✔
3361

3362
        rowReader, err := stmt.ds.Resolve(ctx, tx, params, scanSpecs)
943✔
3363
        if err != nil {
946✔
3364
                return nil, err
3✔
3365
        }
3✔
3366
        defer func() {
1,880✔
3367
                if err != nil {
947✔
3368
                        rowReader.Close()
7✔
3369
                }
7✔
3370
        }()
3371

3372
        if stmt.joins != nil {
955✔
3373
                var jointRowReader *jointRowReader
15✔
3374
                jointRowReader, err = newJointRowReader(rowReader, stmt.joins)
15✔
3375
                if err != nil {
16✔
3376
                        return nil, err
1✔
3377
                }
1✔
3378
                rowReader = jointRowReader
14✔
3379
        }
3380

3381
        if stmt.where != nil {
1,468✔
3382
                rowReader = newConditionalRowReader(rowReader, stmt.where)
529✔
3383
        }
529✔
3384

3385
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
1,019✔
3386
                if len(scanSpecs.groupBySortExps) > 0 {
92✔
3387
                        var sortRowReader *sortRowReader
12✔
3388
                        sortRowReader, err = newSortRowReader(rowReader, scanSpecs.groupBySortExps)
12✔
3389
                        if err != nil {
12✔
3390
                                return nil, err
×
3391
                        }
×
3392
                        rowReader = sortRowReader
12✔
3393
                }
3394

3395
                var groupedRowReader *groupedRowReader
80✔
3396
                groupedRowReader, err = newGroupedRowReader(rowReader, allAggregations(stmt.targets), stmt.extractGroupByCols(), stmt.groupBy)
80✔
3397
                if err != nil {
82✔
3398
                        return nil, err
2✔
3399
                }
2✔
3400
                rowReader = groupedRowReader
78✔
3401

78✔
3402
                if stmt.having != nil {
82✔
3403
                        rowReader = newConditionalRowReader(rowReader, stmt.having)
4✔
3404
                }
4✔
3405
        }
3406

3407
        if len(scanSpecs.orderBySortExps) > 0 {
986✔
3408
                var sortRowReader *sortRowReader
49✔
3409
                sortRowReader, err = newSortRowReader(rowReader, stmt.orderBy)
49✔
3410
                if err != nil {
50✔
3411
                        return nil, err
1✔
3412
                }
1✔
3413
                rowReader = sortRowReader
48✔
3414
        }
3415

3416
        projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.targets)
936✔
3417
        if err != nil {
939✔
3418
                return nil, err
3✔
3419
        }
3✔
3420
        rowReader = projectedRowReader
933✔
3421

933✔
3422
        if stmt.distinct {
940✔
3423
                var distinctRowReader *distinctRowReader
7✔
3424
                distinctRowReader, err = newDistinctRowReader(ctx, rowReader)
7✔
3425
                if err != nil {
7✔
3426
                        return nil, err
×
3427
                }
×
3428
                rowReader = distinctRowReader
7✔
3429
        }
3430

3431
        if stmt.offset != nil {
983✔
3432
                var offset int
50✔
3433
                offset, err = evalExpAsInt(tx, stmt.offset, params)
50✔
3434
                if err != nil {
50✔
3435
                        return nil, fmt.Errorf("%w: invalid offset", err)
×
3436
                }
×
3437

3438
                rowReader = newOffsetRowReader(rowReader, offset)
50✔
3439
        }
3440

3441
        if stmt.limit != nil {
1,030✔
3442
                var limit int
97✔
3443
                limit, err = evalExpAsInt(tx, stmt.limit, params)
97✔
3444
                if err != nil {
97✔
3445
                        return nil, fmt.Errorf("%w: invalid limit", err)
×
3446
                }
×
3447

3448
                if limit < 0 {
97✔
3449
                        return nil, fmt.Errorf("%w: invalid limit", ErrIllegalArguments)
×
3450
                }
×
3451

3452
                if limit > 0 {
140✔
3453
                        rowReader = newLimitRowReader(rowReader, limit)
43✔
3454
                }
43✔
3455
        }
3456
        return rowReader, nil
933✔
3457
}
3458

3459
func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
942✔
3460
        if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
948✔
3461
                if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {
8✔
3462
                        return orderByExps, nil
2✔
3463
                }
2✔
3464

3465
                if ordExpsHasPrefix(groupByCols, orderByExps, stmt.Alias()) {
5✔
3466
                        for i := range orderByExps {
2✔
3467
                                groupByCols[i].descOrder = orderByExps[i].descOrder
1✔
3468
                        }
1✔
3469
                        return groupByCols, nil
1✔
3470
                }
3471
        }
3472
        return groupByCols, orderByExps
939✔
3473
}
3474

3475
func ordExpsHasPrefix(cols, prefix []*OrdExp, table string) bool {
10✔
3476
        if len(prefix) > len(cols) {
12✔
3477
                return false
2✔
3478
        }
2✔
3479

3480
        for i := range prefix {
17✔
3481
                ls := prefix[i].AsSelector()
9✔
3482
                rs := cols[i].AsSelector()
9✔
3483

9✔
3484
                if ls == nil || rs == nil {
9✔
3485
                        return false
×
3486
                }
×
3487

3488
                if EncodeSelector(ls.resolve(table)) != EncodeSelector(rs.resolve(table)) {
14✔
3489
                        return false
5✔
3490
                }
5✔
3491
        }
3492
        return true
3✔
3493
}
3494

3495
func (stmt *SelectStmt) groupByOrdExps() []*OrdExp {
959✔
3496
        groupByCols := stmt.groupBy
959✔
3497

959✔
3498
        ordExps := make([]*OrdExp, 0, len(groupByCols))
959✔
3499
        for _, col := range groupByCols {
1,006✔
3500
                ordExps = append(ordExps, &OrdExp{exp: col})
47✔
3501
        }
47✔
3502
        return ordExps
959✔
3503
}
3504

3505
func ordExpsHaveAggregations(exps []*OrdExp) bool {
7✔
3506
        for _, e := range exps {
17✔
3507
                if _, isAgg := e.exp.(*AggColSelector); isAgg {
11✔
3508
                        return true
1✔
3509
                }
1✔
3510
        }
3511
        return false
6✔
3512
}
3513

3514
func (stmt *SelectStmt) containsAggregations() bool {
1,559✔
3515
        for _, sel := range stmt.targetSelectors() {
3,155✔
3516
                _, isAgg := sel.(*AggColSelector)
1,596✔
3517
                if isAgg {
1,759✔
3518
                        return true
163✔
3519
                }
163✔
3520
        }
3521
        return false
1,396✔
3522
}
3523

3524
func evalExpAsInt(tx *SQLTx, exp ValueExp, params map[string]interface{}) (int, error) {
147✔
3525
        offset, err := exp.substitute(params)
147✔
3526
        if err != nil {
147✔
3527
                return 0, err
×
3528
        }
×
3529

3530
        texp, err := offset.reduce(tx, nil, "")
147✔
3531
        if err != nil {
147✔
3532
                return 0, err
×
3533
        }
×
3534

3535
        convVal, err := mayApplyImplicitConversion(texp.RawValue(), IntegerType)
147✔
3536
        if err != nil {
147✔
3537
                return 0, ErrInvalidValue
×
3538
        }
×
3539

3540
        num, ok := convVal.(int64)
147✔
3541
        if !ok {
147✔
3542
                return 0, ErrInvalidValue
×
3543
        }
×
3544

3545
        if num > math.MaxInt32 {
147✔
3546
                return 0, ErrInvalidValue
×
3547
        }
×
3548

3549
        return int(num), nil
147✔
3550
}
3551

3552
func (stmt *SelectStmt) Alias() string {
167✔
3553
        if stmt.as == "" {
333✔
3554
                return stmt.ds.Alias()
166✔
3555
        }
166✔
3556

3557
        return stmt.as
1✔
3558
}
3559

3560
func (stmt *SelectStmt) hasTxMetadata() bool {
876✔
3561
        for _, sel := range stmt.targetSelectors() {
1,692✔
3562
                switch s := sel.(type) {
816✔
3563
                case *ColSelector:
702✔
3564
                        if s.col == txMetadataCol {
703✔
3565
                                return true
1✔
3566
                        }
1✔
3567
                case *JSONSelector:
21✔
3568
                        if s.ColSelector.col == txMetadataCol {
24✔
3569
                                return true
3✔
3570
                        }
3✔
3571
                }
3572
        }
3573
        return false
872✔
3574
}
3575

3576
func (stmt *SelectStmt) genScanSpecs(tx *SQLTx, params map[string]interface{}) (*ScanSpecs, error) {
959✔
3577
        groupByCols, orderByCols := stmt.groupByOrdExps(), stmt.orderBy
959✔
3578

959✔
3579
        tableRef, isTableRef := stmt.ds.(*tableRef)
959✔
3580
        if !isTableRef {
1,025✔
3581
                groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)
66✔
3582

66✔
3583
                return &ScanSpecs{
66✔
3584
                        groupBySortExps: groupByCols,
66✔
3585
                        orderBySortExps: orderByCols,
66✔
3586
                }, nil
66✔
3587
        }
66✔
3588

3589
        table, err := tableRef.referencedTable(tx)
893✔
3590
        if err != nil {
908✔
3591
                if tx.engine.tableResolveFor(tableRef.table) != nil {
16✔
3592
                        return &ScanSpecs{
1✔
3593
                                groupBySortExps: groupByCols,
1✔
3594
                                orderBySortExps: orderByCols,
1✔
3595
                        }, nil
1✔
3596
                }
1✔
3597
                return nil, err
14✔
3598
        }
3599

3600
        rangesByColID := make(map[uint32]*typedValueRange)
878✔
3601
        if stmt.where != nil {
1,396✔
3602
                err = stmt.where.selectorRanges(table, tableRef.Alias(), params, rangesByColID)
518✔
3603
                if err != nil {
520✔
3604
                        return nil, err
2✔
3605
                }
2✔
3606
        }
3607

3608
        preferredIndex, err := stmt.getPreferredIndex(table)
876✔
3609
        if err != nil {
876✔
3610
                return nil, err
×
3611
        }
×
3612

3613
        var sortingIndex *Index
876✔
3614
        if preferredIndex == nil {
1,722✔
3615
                sortingIndex = stmt.selectSortingIndex(groupByCols, orderByCols, table, rangesByColID)
846✔
3616
        } else {
876✔
3617
                sortingIndex = preferredIndex
30✔
3618
        }
30✔
3619

3620
        if sortingIndex == nil {
1,629✔
3621
                sortingIndex = table.primaryIndex
753✔
3622
        }
753✔
3623

3624
        if tableRef.history && !sortingIndex.IsPrimary() {
876✔
3625
                return nil, fmt.Errorf("%w: historical queries are supported over primary index", ErrIllegalArguments)
×
3626
        }
×
3627

3628
        var descOrder bool
876✔
3629
        if len(groupByCols) > 0 && sortingIndex.coversOrdCols(groupByCols, rangesByColID) {
893✔
3630
                groupByCols = nil
17✔
3631
        }
17✔
3632

3633
        if len(groupByCols) == 0 && len(orderByCols) > 0 && sortingIndex.coversOrdCols(orderByCols, rangesByColID) {
978✔
3634
                descOrder = orderByCols[0].descOrder
102✔
3635
                orderByCols = nil
102✔
3636
        }
102✔
3637

3638
        groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)
876✔
3639

876✔
3640
        return &ScanSpecs{
876✔
3641
                Index:             sortingIndex,
876✔
3642
                rangesByColID:     rangesByColID,
876✔
3643
                IncludeHistory:    tableRef.history,
876✔
3644
                IncludeTxMetadata: stmt.hasTxMetadata(),
876✔
3645
                DescOrder:         descOrder,
876✔
3646
                groupBySortExps:   groupByCols,
876✔
3647
                orderBySortExps:   orderByCols,
876✔
3648
        }, nil
876✔
3649
}
3650

3651
func (stmt *SelectStmt) selectSortingIndex(groupByCols, orderByCols []*OrdExp, table *Table, rangesByColId map[uint32]*typedValueRange) *Index {
846✔
3652
        sortCols := groupByCols
846✔
3653
        if len(sortCols) == 0 {
1,666✔
3654
                sortCols = orderByCols
820✔
3655
        }
820✔
3656

3657
        if len(sortCols) == 0 {
1,553✔
3658
                return nil
707✔
3659
        }
707✔
3660

3661
        for _, idx := range table.indexes {
364✔
3662
                if idx.coversOrdCols(sortCols, rangesByColId) {
318✔
3663
                        return idx
93✔
3664
                }
93✔
3665
        }
3666
        return nil
46✔
3667
}
3668

3669
func (stmt *SelectStmt) getPreferredIndex(table *Table) (*Index, error) {
876✔
3670
        if len(stmt.indexOn) == 0 {
1,722✔
3671
                return nil, nil
846✔
3672
        }
846✔
3673

3674
        cols := make([]*Column, len(stmt.indexOn))
30✔
3675
        for i, colName := range stmt.indexOn {
80✔
3676
                col, err := table.GetColumnByName(colName)
50✔
3677
                if err != nil {
50✔
3678
                        return nil, err
×
3679
                }
×
3680

3681
                cols[i] = col
50✔
3682
        }
3683
        return table.GetIndexByName(indexName(table.name, cols))
30✔
3684
}
3685

3686
type UnionStmt struct {
3687
        distinct    bool
3688
        left, right DataSource
3689
}
3690

3691
func (stmt *UnionStmt) readOnly() bool {
1✔
3692
        return true
1✔
3693
}
1✔
3694

3695
func (stmt *UnionStmt) requiredPrivileges() []SQLPrivilege {
1✔
3696
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3697
}
1✔
3698

3699
func (stmt *UnionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3700
        err := stmt.left.inferParameters(ctx, tx, params)
1✔
3701
        if err != nil {
1✔
3702
                return err
×
3703
        }
×
3704
        return stmt.right.inferParameters(ctx, tx, params)
1✔
3705
}
3706

3707
func (stmt *UnionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
3708
        _, err := stmt.left.execAt(ctx, tx, params)
9✔
3709
        if err != nil {
9✔
3710
                return tx, err
×
3711
        }
×
3712

3713
        return stmt.right.execAt(ctx, tx, params)
9✔
3714
}
3715

3716
func (stmt *UnionStmt) resolveUnionAll(ctx context.Context, tx *SQLTx, params map[string]interface{}) (ret RowReader, err error) {
11✔
3717
        leftRowReader, err := stmt.left.Resolve(ctx, tx, params, nil)
11✔
3718
        if err != nil {
12✔
3719
                return nil, err
1✔
3720
        }
1✔
3721
        defer func() {
20✔
3722
                if err != nil {
14✔
3723
                        leftRowReader.Close()
4✔
3724
                }
4✔
3725
        }()
3726

3727
        rightRowReader, err := stmt.right.Resolve(ctx, tx, params, nil)
10✔
3728
        if err != nil {
11✔
3729
                return nil, err
1✔
3730
        }
1✔
3731
        defer func() {
18✔
3732
                if err != nil {
12✔
3733
                        rightRowReader.Close()
3✔
3734
                }
3✔
3735
        }()
3736

3737
        rowReader, err := newUnionRowReader(ctx, []RowReader{leftRowReader, rightRowReader})
9✔
3738
        if err != nil {
12✔
3739
                return nil, err
3✔
3740
        }
3✔
3741

3742
        return rowReader, nil
6✔
3743
}
3744

3745
func (stmt *UnionStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
11✔
3746
        rowReader, err := stmt.resolveUnionAll(ctx, tx, params)
11✔
3747
        if err != nil {
16✔
3748
                return nil, err
5✔
3749
        }
5✔
3750
        defer func() {
12✔
3751
                if err != nil {
7✔
3752
                        rowReader.Close()
1✔
3753
                }
1✔
3754
        }()
3755

3756
        if stmt.distinct {
11✔
3757
                distinctReader, err := newDistinctRowReader(ctx, rowReader)
5✔
3758
                if err != nil {
6✔
3759
                        return nil, err
1✔
3760
                }
1✔
3761
                rowReader = distinctReader
4✔
3762
        }
3763

3764
        return rowReader, nil
5✔
3765
}
3766

3767
func (stmt *UnionStmt) Alias() string {
×
3768
        return ""
×
3769
}
×
3770

3771
func NewTableRef(table string, as string) *tableRef {
179✔
3772
        return &tableRef{
179✔
3773
                table: table,
179✔
3774
                as:    as,
179✔
3775
        }
179✔
3776
}
179✔
3777

3778
type tableRef struct {
3779
        table   string
3780
        history bool
3781
        period  period
3782
        as      string
3783
}
3784

3785
func (ref *tableRef) readOnly() bool {
1✔
3786
        return true
1✔
3787
}
1✔
3788

3789
func (ref *tableRef) requiredPrivileges() []SQLPrivilege {
1✔
3790
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3791
}
1✔
3792

3793
type period struct {
3794
        start *openPeriod
3795
        end   *openPeriod
3796
}
3797

3798
type openPeriod struct {
3799
        inclusive bool
3800
        instant   periodInstant
3801
}
3802

3803
type periodInstant struct {
3804
        exp         ValueExp
3805
        instantType instantType
3806
}
3807

3808
type instantType = int
3809

3810
const (
3811
        txInstant instantType = iota
3812
        timeInstant
3813
)
3814

3815
func (i periodInstant) resolve(tx *SQLTx, params map[string]interface{}, asc, inclusive bool) (uint64, error) {
81✔
3816
        exp, err := i.exp.substitute(params)
81✔
3817
        if err != nil {
81✔
3818
                return 0, err
×
3819
        }
×
3820

3821
        instantVal, err := exp.reduce(tx, nil, "")
81✔
3822
        if err != nil {
83✔
3823
                return 0, err
2✔
3824
        }
2✔
3825

3826
        if i.instantType == txInstant {
124✔
3827
                txID, ok := instantVal.RawValue().(int64)
45✔
3828
                if !ok {
45✔
3829
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %s given", ErrIllegalArguments, instantVal.Type())
×
3830
                }
×
3831

3832
                if txID <= 0 {
52✔
3833
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %d given", ErrIllegalArguments, txID)
7✔
3834
                }
7✔
3835

3836
                if inclusive {
61✔
3837
                        return uint64(txID), nil
23✔
3838
                }
23✔
3839

3840
                if asc {
26✔
3841
                        return uint64(txID + 1), nil
11✔
3842
                }
11✔
3843

3844
                if txID <= 1 {
5✔
3845
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be greater than 1, %d given", ErrIllegalArguments, txID)
1✔
3846
                }
1✔
3847

3848
                return uint64(txID - 1), nil
3✔
3849
        } else {
34✔
3850

34✔
3851
                var ts time.Time
34✔
3852

34✔
3853
                if instantVal.Type() == TimestampType {
67✔
3854
                        ts = instantVal.RawValue().(time.Time)
33✔
3855
                } else {
34✔
3856
                        conv, err := getConverter(instantVal.Type(), TimestampType)
1✔
3857
                        if err != nil {
1✔
3858
                                return 0, err
×
3859
                        }
×
3860

3861
                        tval, err := conv(instantVal)
1✔
3862
                        if err != nil {
1✔
3863
                                return 0, err
×
3864
                        }
×
3865

3866
                        ts = tval.RawValue().(time.Time)
1✔
3867
                }
3868

3869
                sts := ts
34✔
3870

34✔
3871
                if asc {
57✔
3872
                        if !inclusive {
34✔
3873
                                sts = sts.Add(1 * time.Second)
11✔
3874
                        }
11✔
3875

3876
                        txHdr, err := tx.engine.store.FirstTxSince(sts)
23✔
3877
                        if err != nil {
34✔
3878
                                return 0, err
11✔
3879
                        }
11✔
3880

3881
                        return txHdr.ID, nil
12✔
3882
                }
3883

3884
                if !inclusive {
11✔
3885
                        sts = sts.Add(-1 * time.Second)
×
3886
                }
×
3887

3888
                txHdr, err := tx.engine.store.LastTxUntil(sts)
11✔
3889
                if err != nil {
11✔
3890
                        return 0, err
×
3891
                }
×
3892

3893
                return txHdr.ID, nil
11✔
3894
        }
3895
}
3896

3897
func (stmt *tableRef) referencedTable(tx *SQLTx) (*Table, error) {
3,943✔
3898
        table, err := tx.catalog.GetTableByName(stmt.table)
3,943✔
3899
        if err != nil {
3,963✔
3900
                return nil, err
20✔
3901
        }
20✔
3902
        return table, nil
3,923✔
3903
}
3904

3905
func (stmt *tableRef) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3906
        return nil
1✔
3907
}
1✔
3908

3909
func (stmt *tableRef) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3910
        return tx, nil
×
3911
}
×
3912

3913
func (stmt *tableRef) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
906✔
3914
        if tx == nil {
906✔
3915
                return nil, ErrIllegalArguments
×
3916
        }
×
3917

3918
        table, err := stmt.referencedTable(tx)
906✔
3919
        if err == nil {
1,811✔
3920
                return newRawRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
905✔
3921
        }
905✔
3922

3923
        if resolver := tx.engine.tableResolveFor(stmt.table); resolver != nil {
2✔
3924
                return resolver.Resolve(ctx, tx, stmt.Alias())
1✔
3925
        }
1✔
3926
        return nil, err
×
3927
}
3928

3929
func (stmt *tableRef) Alias() string {
683✔
3930
        if stmt.as == "" {
1,206✔
3931
                return stmt.table
523✔
3932
        }
523✔
3933
        return stmt.as
160✔
3934
}
3935

3936
type valuesDataSource struct {
3937
        inferTypes bool
3938
        rows       []*RowSpec
3939
}
3940

3941
func NewValuesDataSource(rows []*RowSpec) *valuesDataSource {
120✔
3942
        return &valuesDataSource{
120✔
3943
                rows: rows,
120✔
3944
        }
120✔
3945
}
120✔
3946

3947
func (ds *valuesDataSource) readOnly() bool {
×
3948
        return true
×
3949
}
×
3950

3951
func (ds *valuesDataSource) requiredPrivileges() []SQLPrivilege {
97✔
3952
        return nil
97✔
3953
}
97✔
3954

3955
func (ds *valuesDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3956
        return tx, nil
×
3957
}
×
3958

3959
func (ds *valuesDataSource) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
3960
        return nil
×
3961
}
×
3962

3963
func (ds *valuesDataSource) Alias() string {
×
3964
        return ""
×
3965
}
×
3966

3967
func (ds *valuesDataSource) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
2,123✔
3968
        if tx == nil {
2,123✔
3969
                return nil, ErrIllegalArguments
×
3970
        }
×
3971

3972
        cols := make([]ColDescriptor, len(ds.rows[0].Values))
2,123✔
3973
        for i := range cols {
9,775✔
3974
                cols[i] = ColDescriptor{
7,652✔
3975
                        Type:   AnyType,
7,652✔
3976
                        Column: fmt.Sprintf("col%d", i),
7,652✔
3977
                }
7,652✔
3978
        }
7,652✔
3979

3980
        emptyColsDesc, emptyParams := map[string]ColDescriptor{}, map[string]string{}
2,123✔
3981

2,123✔
3982
        if ds.inferTypes {
2,132✔
3983
                for i := 0; i < len(cols); i++ {
48✔
3984
                        t := AnyType
39✔
3985
                        for j := 0; j < len(ds.rows); j++ {
146✔
3986
                                e, err := ds.rows[j].Values[i].substitute(params)
107✔
3987
                                if err != nil {
107✔
3988
                                        return nil, err
×
3989
                                }
×
3990

3991
                                it, err := e.inferType(emptyColsDesc, emptyParams, "")
107✔
3992
                                if err != nil {
107✔
3993
                                        return nil, err
×
3994
                                }
×
3995

3996
                                if t == AnyType {
146✔
3997
                                        t = it
39✔
3998
                                } else if t != it && it != AnyType {
109✔
3999
                                        return nil, fmt.Errorf("cannot match types %s and %s", t, it)
2✔
4000
                                }
2✔
4001
                        }
4002
                        cols[i].Type = t
37✔
4003
                }
4004
        }
4005

4006
        values := make([][]ValueExp, len(ds.rows))
2,121✔
4007
        for i, rowSpec := range ds.rows {
4,351✔
4008
                values[i] = rowSpec.Values
2,230✔
4009
        }
2,230✔
4010
        return NewValuesRowReader(tx, params, cols, ds.inferTypes, "values", values)
2,121✔
4011
}
4012

4013
type JoinSpec struct {
4014
        joinType JoinType
4015
        ds       DataSource
4016
        cond     ValueExp
4017
        indexOn  []string
4018
}
4019

4020
type OrdExp struct {
4021
        exp       ValueExp
4022
        descOrder bool
4023
}
4024

4025
func (oc *OrdExp) AsSelector() Selector {
714✔
4026
        sel, ok := oc.exp.(Selector)
714✔
4027
        if ok {
1,374✔
4028
                return sel
660✔
4029
        }
660✔
4030
        return nil
54✔
4031
}
4032

4033
func NewOrdCol(table string, col string, descOrder bool) *OrdExp {
1✔
4034
        return &OrdExp{
1✔
4035
                exp:       NewColSelector(table, col),
1✔
4036
                descOrder: descOrder,
1✔
4037
        }
1✔
4038
}
1✔
4039

4040
type Selector interface {
4041
        ValueExp
4042
        resolve(implicitTable string) (aggFn, table, col string)
4043
}
4044

4045
// ColSelector refers to a column, optionally qualified by a table name or alias.
type ColSelector struct {
	table string // empty when the column is unqualified
	col   string
}
4049

4050
func NewColSelector(table, col string) *ColSelector {
126✔
4051
        return &ColSelector{
126✔
4052
                table: table,
126✔
4053
                col:   col,
126✔
4054
        }
126✔
4055
}
126✔
4056

4057
func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) {
867,662✔
4058
        table = implicitTable
867,662✔
4059
        if sel.table != "" {
1,168,674✔
4060
                table = sel.table
301,012✔
4061
        }
301,012✔
4062
        return "", table, sel.col
867,662✔
4063
}
4064

4065
func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
685✔
4066
        _, table, col := sel.resolve(implicitTable)
685✔
4067
        encSel := EncodeSelector("", table, col)
685✔
4068

685✔
4069
        desc, ok := cols[encSel]
685✔
4070
        if !ok {
688✔
4071
                return AnyType, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
3✔
4072
        }
3✔
4073
        return desc.Type, nil
682✔
4074
}
4075

4076
func (sel *ColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
4077
        _, table, col := sel.resolve(implicitTable)
15✔
4078
        encSel := EncodeSelector("", table, col)
15✔
4079

15✔
4080
        desc, ok := cols[encSel]
15✔
4081
        if !ok {
17✔
4082
                return fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
2✔
4083
        }
2✔
4084

4085
        if desc.Type != t {
16✔
4086
                return fmt.Errorf("%w: %v(%s) can not be interpreted as type %v", ErrInvalidTypes, desc.Type, encSel, t)
3✔
4087
        }
3✔
4088

4089
        return nil
10✔
4090
}
4091

4092
func (sel *ColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
161,866✔
4093
        return sel, nil
161,866✔
4094
}
161,866✔
4095

4096
func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
715,150✔
4097
        if row == nil {
715,151✔
4098
                return nil, fmt.Errorf("%w: no row to evaluate in current context", ErrInvalidValue)
1✔
4099
        }
1✔
4100

4101
        aggFn, table, col := sel.resolve(implicitTable)
715,149✔
4102

715,149✔
4103
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
715,149✔
4104
        if !ok {
715,156✔
4105
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
7✔
4106
        }
7✔
4107
        return v, nil
715,142✔
4108
}
4109

4110
func (sel *ColSelector) selectors() []Selector {
926✔
4111
        return []Selector{sel}
926✔
4112
}
926✔
4113

4114
func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
568✔
4115
        aggFn, table, col := sel.resolve(implicitTable)
568✔
4116

568✔
4117
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
568✔
4118
        if !ok {
846✔
4119
                return sel
278✔
4120
        }
278✔
4121

4122
        return v
290✔
4123
}
4124

4125
func (sel *ColSelector) isConstant() bool {
12✔
4126
        return false
12✔
4127
}
12✔
4128

4129
func (sel *ColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
11✔
4130
        return nil
11✔
4131
}
11✔
4132

4133
func (sel *ColSelector) String() string {
48✔
4134
        return sel.col
48✔
4135
}
48✔
4136

4137
type AggColSelector struct {
4138
        aggFn AggregateFn
4139
        table string
4140
        col   string
4141
}
4142

4143
func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector {
16✔
4144
        return &AggColSelector{
16✔
4145
                aggFn: aggFn,
16✔
4146
                table: table,
16✔
4147
                col:   col,
16✔
4148
        }
16✔
4149
}
16✔
4150

4151
// EncodeSelector builds the canonical key of a selector in the form
// "aggFn(table.col)". For plain columns, aggFn is empty and the key is
// "(table.col)". These keys are used to index row values by selector.
//
// Kept as a single concatenation expression so that it compiles to one
// allocation; this function sits on a very hot path.
func EncodeSelector(aggFn, table, col string) string {
	return aggFn + "(" + table + "." + col + ")"
}
1,417,497✔
4154

4155
func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,586✔
4156
        table = implicitTable
1,586✔
4157
        if sel.table != "" {
1,717✔
4158
                table = sel.table
131✔
4159
        }
131✔
4160
        return sel.aggFn, table, sel.col
1,586✔
4161
}
4162

4163
func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
36✔
4164
        if sel.aggFn == COUNT {
55✔
4165
                return IntegerType, nil
19✔
4166
        }
19✔
4167

4168
        colSelector := &ColSelector{table: sel.table, col: sel.col}
17✔
4169

17✔
4170
        if sel.aggFn == SUM || sel.aggFn == AVG {
24✔
4171
                t, err := colSelector.inferType(cols, params, implicitTable)
7✔
4172
                if err != nil {
7✔
4173
                        return AnyType, err
×
4174
                }
×
4175

4176
                if t != IntegerType && t != Float64Type {
7✔
4177
                        return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
×
4178

×
4179
                }
×
4180

4181
                return t, nil
7✔
4182
        }
4183

4184
        return colSelector.inferType(cols, params, implicitTable)
10✔
4185
}
4186

4187
func (sel *AggColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
8✔
4188
        if sel.aggFn == COUNT {
10✔
4189
                if t != IntegerType {
3✔
4190
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4191
                }
1✔
4192
                return nil
1✔
4193
        }
4194

4195
        colSelector := &ColSelector{table: sel.table, col: sel.col}
6✔
4196

6✔
4197
        if sel.aggFn == SUM || sel.aggFn == AVG {
10✔
4198
                if t != IntegerType && t != Float64Type {
5✔
4199
                        return fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
1✔
4200
                }
1✔
4201
        }
4202

4203
        return colSelector.requiresType(t, cols, params, implicitTable)
5✔
4204
}
4205

4206
func (sel *AggColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
412✔
4207
        return sel, nil
412✔
4208
}
412✔
4209

4210
func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
459✔
4211
        if row == nil {
460✔
4212
                return nil, fmt.Errorf("%w: no row to evaluate aggregation (%s) in current context", ErrInvalidValue, sel.aggFn)
1✔
4213
        }
1✔
4214

4215
        v, ok := row.ValuesBySelector[EncodeSelector(sel.resolve(implicitTable))]
458✔
4216
        if !ok {
459✔
4217
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, sel.col)
1✔
4218
        }
1✔
4219
        return v, nil
457✔
4220
}
4221

4222
func (sel *AggColSelector) selectors() []Selector {
232✔
4223
        return []Selector{sel}
232✔
4224
}
232✔
4225

4226
func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4227
        return sel
×
4228
}
×
4229

4230
func (sel *AggColSelector) isConstant() bool {
1✔
4231
        return false
1✔
4232
}
1✔
4233

4234
func (sel *AggColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
4235
        return nil
×
4236
}
×
4237

4238
func (sel *AggColSelector) String() string {
×
4239
        return sel.aggFn + "(" + sel.col + ")"
×
4240
}
×
4241

4242
type NumExp struct {
4243
        op          NumOperator
4244
        left, right ValueExp
4245
}
4246

4247
func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
17✔
4248
        // First step - check if we can infer the type of sub-expressions
17✔
4249
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
17✔
4250
        if err != nil {
17✔
4251
                return AnyType, err
×
4252
        }
×
4253
        if tleft != AnyType && tleft != IntegerType && tleft != Float64Type && tleft != JSONType {
17✔
4254
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft)
×
4255
        }
×
4256

4257
        tright, err := bexp.right.inferType(cols, params, implicitTable)
17✔
4258
        if err != nil {
17✔
4259
                return AnyType, err
×
4260
        }
×
4261
        if tright != AnyType && tright != IntegerType && tright != Float64Type && tright != JSONType {
19✔
4262
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright)
2✔
4263
        }
2✔
4264

4265
        if tleft == IntegerType && tright == IntegerType {
19✔
4266
                // Both sides are integer types - the result is also integer
4✔
4267
                return IntegerType, nil
4✔
4268
        }
4✔
4269

4270
        if tleft != AnyType && tright != AnyType {
20✔
4271
                // Both sides have concrete types but at least one of them is float
9✔
4272
                return Float64Type, nil
9✔
4273
        }
9✔
4274

4275
        // Both sides are ambiguous
4276
        return AnyType, nil
2✔
4277
}
4278

4279
func copyParams(params map[string]SQLValueType) map[string]SQLValueType {
11✔
4280
        ret := make(map[string]SQLValueType, len(params))
11✔
4281
        for k, v := range params {
15✔
4282
                ret[k] = v
4✔
4283
        }
4✔
4284
        return ret
11✔
4285
}
4286

4287
func restoreParams(params, restore map[string]SQLValueType) {
2✔
4288
        for k := range params {
2✔
4289
                delete(params, k)
×
4290
        }
×
4291
        for k, v := range restore {
2✔
4292
                params[k] = v
×
4293
        }
×
4294
}
4295

4296
func (bexp *NumExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4297
        if t != IntegerType && t != Float64Type {
8✔
4298
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4299
        }
1✔
4300

4301
        floatArgs := 2
6✔
4302
        paramsOrig := copyParams(params)
6✔
4303
        err := bexp.left.requiresType(t, cols, params, implicitTable)
6✔
4304
        if err != nil && t == Float64Type {
7✔
4305
                restoreParams(params, paramsOrig)
1✔
4306
                floatArgs--
1✔
4307
                err = bexp.left.requiresType(IntegerType, cols, params, implicitTable)
1✔
4308
        }
1✔
4309
        if err != nil {
7✔
4310
                return err
1✔
4311
        }
1✔
4312

4313
        paramsOrig = copyParams(params)
5✔
4314
        err = bexp.right.requiresType(t, cols, params, implicitTable)
5✔
4315
        if err != nil && t == Float64Type {
6✔
4316
                restoreParams(params, paramsOrig)
1✔
4317
                floatArgs--
1✔
4318
                err = bexp.right.requiresType(IntegerType, cols, params, implicitTable)
1✔
4319
        }
1✔
4320
        if err != nil {
7✔
4321
                return err
2✔
4322
        }
2✔
4323

4324
        if t == Float64Type && floatArgs == 0 {
3✔
4325
                // Currently this case requires explicit float cast
×
4326
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
×
4327
        }
×
4328

4329
        return nil
3✔
4330
}
4331

4332
func (bexp *NumExp) substitute(params map[string]interface{}) (ValueExp, error) {
187✔
4333
        rlexp, err := bexp.left.substitute(params)
187✔
4334
        if err != nil {
187✔
4335
                return nil, err
×
4336
        }
×
4337

4338
        rrexp, err := bexp.right.substitute(params)
187✔
4339
        if err != nil {
187✔
4340
                return nil, err
×
4341
        }
×
4342

4343
        bexp.left = rlexp
187✔
4344
        bexp.right = rrexp
187✔
4345

187✔
4346
        return bexp, nil
187✔
4347
}
4348

4349
func (bexp *NumExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
124,242✔
4350
        vl, err := bexp.left.reduce(tx, row, implicitTable)
124,242✔
4351
        if err != nil {
124,242✔
4352
                return nil, err
×
4353
        }
×
4354

4355
        vr, err := bexp.right.reduce(tx, row, implicitTable)
124,242✔
4356
        if err != nil {
124,242✔
4357
                return nil, err
×
4358
        }
×
4359

4360
        vl = unwrapJSON(vl)
124,242✔
4361
        vr = unwrapJSON(vr)
124,242✔
4362

124,242✔
4363
        return applyNumOperator(bexp.op, vl, vr)
124,242✔
4364
}
4365

4366
func unwrapJSON(v TypedValue) TypedValue {
248,484✔
4367
        if jsonVal, ok := v.(*JSON); ok {
248,584✔
4368
                if sv, isSimple := jsonVal.castToTypedValue(); isSimple {
200✔
4369
                        return sv
100✔
4370
                }
100✔
4371
        }
4372
        return v
248,384✔
4373
}
4374

4375
func (bexp *NumExp) selectors() []Selector {
13✔
4376
        return append(bexp.left.selectors(), bexp.right.selectors()...)
13✔
4377
}
13✔
4378

4379
func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4380
        return &NumExp{
1✔
4381
                op:    bexp.op,
1✔
4382
                left:  bexp.left.reduceSelectors(row, implicitTable),
1✔
4383
                right: bexp.right.reduceSelectors(row, implicitTable),
1✔
4384
        }
1✔
4385
}
1✔
4386

4387
func (bexp *NumExp) isConstant() bool {
5✔
4388
        return bexp.left.isConstant() && bexp.right.isConstant()
5✔
4389
}
5✔
4390

4391
func (bexp *NumExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
4✔
4392
        return nil
4✔
4393
}
4✔
4394

4395
func (bexp *NumExp) String() string {
18✔
4396
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), NumOperatorString(bexp.op), bexp.right.String())
18✔
4397
}
18✔
4398

4399
type NotBoolExp struct {
4400
        exp ValueExp
4401
}
4402

4403
func (bexp *NotBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4404
        err := bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
1✔
4405
        if err != nil {
1✔
4406
                return AnyType, err
×
4407
        }
×
4408

4409
        return BooleanType, nil
1✔
4410
}
4411

4412
func (bexp *NotBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
4413
        if t != BooleanType {
7✔
4414
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4415
        }
1✔
4416

4417
        return bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
5✔
4418
}
4419

4420
func (bexp *NotBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
22✔
4421
        rexp, err := bexp.exp.substitute(params)
22✔
4422
        if err != nil {
22✔
4423
                return nil, err
×
4424
        }
×
4425

4426
        bexp.exp = rexp
22✔
4427

22✔
4428
        return bexp, nil
22✔
4429
}
4430

4431
func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
22✔
4432
        v, err := bexp.exp.reduce(tx, row, implicitTable)
22✔
4433
        if err != nil {
22✔
4434
                return nil, err
×
4435
        }
×
4436

4437
        r, isBool := v.RawValue().(bool)
22✔
4438
        if !isBool {
22✔
4439
                return nil, ErrInvalidCondition
×
4440
        }
×
4441

4442
        return &Bool{val: !r}, nil
22✔
4443
}
4444

4445
func (bexp *NotBoolExp) selectors() []Selector {
×
4446
        return bexp.exp.selectors()
×
4447
}
×
4448

4449
func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4450
        return &NotBoolExp{
×
4451
                exp: bexp.exp.reduceSelectors(row, implicitTable),
×
4452
        }
×
4453
}
×
4454

4455
func (bexp *NotBoolExp) isConstant() bool {
1✔
4456
        return bexp.exp.isConstant()
1✔
4457
}
1✔
4458

4459
func (bexp *NotBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
4460
        return nil
7✔
4461
}
7✔
4462

4463
func (bexp *NotBoolExp) String() string {
12✔
4464
        return fmt.Sprintf("(NOT %s)", bexp.exp.String())
12✔
4465
}
12✔
4466

4467
type LikeBoolExp struct {
4468
        val     ValueExp
4469
        notLike bool
4470
        pattern ValueExp
4471
}
4472

4473
func NewLikeBoolExp(val ValueExp, notLike bool, pattern ValueExp) *LikeBoolExp {
4✔
4474
        return &LikeBoolExp{
4✔
4475
                val:     val,
4✔
4476
                notLike: notLike,
4✔
4477
                pattern: pattern,
4✔
4478
        }
4✔
4479
}
4✔
4480

4481
func (bexp *LikeBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
4482
        if bexp.val == nil || bexp.pattern == nil {
4✔
4483
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4484
        }
1✔
4485

4486
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
2✔
4487
        if err != nil {
3✔
4488
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4489
        }
1✔
4490

4491
        return BooleanType, nil
1✔
4492
}
4493

4494
func (bexp *LikeBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4495
        if bexp.val == nil || bexp.pattern == nil {
9✔
4496
                return fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
2✔
4497
        }
2✔
4498

4499
        if t != BooleanType {
7✔
4500
                return fmt.Errorf("error using the value of the LIKE operator as %s: %w", t, ErrInvalidTypes)
2✔
4501
        }
2✔
4502

4503
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
3✔
4504
        if err != nil {
4✔
4505
                return fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4506
        }
1✔
4507

4508
        return nil
2✔
4509
}
4510

4511
func (bexp *LikeBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
135✔
4512
        if bexp.val == nil || bexp.pattern == nil {
136✔
4513
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4514
        }
1✔
4515

4516
        val, err := bexp.val.substitute(params)
134✔
4517
        if err != nil {
134✔
4518
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4519
        }
×
4520

4521
        pattern, err := bexp.pattern.substitute(params)
134✔
4522
        if err != nil {
134✔
4523
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4524
        }
×
4525

4526
        return &LikeBoolExp{
134✔
4527
                val:     val,
134✔
4528
                notLike: bexp.notLike,
134✔
4529
                pattern: pattern,
134✔
4530
        }, nil
134✔
4531
}
4532

4533
func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
143✔
4534
        if bexp.val == nil || bexp.pattern == nil {
144✔
4535
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4536
        }
1✔
4537

4538
        rval, err := bexp.val.reduce(tx, row, implicitTable)
142✔
4539
        if err != nil {
142✔
4540
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4541
        }
×
4542

4543
        if rval.IsNull() {
143✔
4544
                return &Bool{val: bexp.notLike}, nil
1✔
4545
        }
1✔
4546

4547
        rvalStr, ok := rval.RawValue().(string)
141✔
4548
        if !ok {
142✔
4549
                return nil, fmt.Errorf("error in 'LIKE' clause: %w (expecting %s)", ErrInvalidTypes, VarcharType)
1✔
4550
        }
1✔
4551

4552
        rpattern, err := bexp.pattern.reduce(tx, row, implicitTable)
140✔
4553
        if err != nil {
140✔
4554
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4555
        }
×
4556

4557
        if rpattern.Type() != VarcharType {
140✔
4558
                return nil, fmt.Errorf("error evaluating 'LIKE' clause: %w", ErrInvalidTypes)
×
4559
        }
×
4560

4561
        matched, err := regexp.MatchString(rpattern.RawValue().(string), rvalStr)
140✔
4562
        if err != nil {
140✔
4563
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4564
        }
×
4565

4566
        return &Bool{val: matched != bexp.notLike}, nil
140✔
4567
}
4568

4569
func (bexp *LikeBoolExp) selectors() []Selector {
1✔
4570
        return bexp.val.selectors()
1✔
4571
}
1✔
4572

4573
func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4574
        return bexp
1✔
4575
}
1✔
4576

4577
func (bexp *LikeBoolExp) isConstant() bool {
2✔
4578
        return false
2✔
4579
}
2✔
4580

4581
func (bexp *LikeBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
4582
        return nil
8✔
4583
}
8✔
4584

4585
func (bexp *LikeBoolExp) String() string {
5✔
4586
        fmtStr := "(%s LIKE %s)"
5✔
4587
        if bexp.notLike {
6✔
4588
                fmtStr = "(%s NOT LIKE %s)"
1✔
4589
        }
1✔
4590
        return fmt.Sprintf(fmtStr, bexp.val.String(), bexp.pattern.String())
5✔
4591
}
4592

4593
type CmpBoolExp struct {
4594
        op          CmpOperator
4595
        left, right ValueExp
4596
}
4597

4598
func NewCmpBoolExp(op CmpOperator, left, right ValueExp) *CmpBoolExp {
67✔
4599
        return &CmpBoolExp{
67✔
4600
                op:    op,
67✔
4601
                left:  left,
67✔
4602
                right: right,
67✔
4603
        }
67✔
4604
}
67✔
4605

4606
func (bexp *CmpBoolExp) Left() ValueExp {
×
4607
        return bexp.left
×
4608
}
×
4609

4610
func (bexp *CmpBoolExp) Right() ValueExp {
×
4611
        return bexp.right
×
4612
}
×
4613

4614
func (bexp *CmpBoolExp) OP() CmpOperator {
×
4615
        return bexp.op
×
4616
}
×
4617

4618
func (bexp *CmpBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
63✔
4619
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
63✔
4620
        if err != nil {
63✔
4621
                return AnyType, err
×
4622
        }
×
4623

4624
        tright, err := bexp.right.inferType(cols, params, implicitTable)
63✔
4625
        if err != nil {
65✔
4626
                return AnyType, err
2✔
4627
        }
2✔
4628

4629
        // unification step
4630

4631
        if tleft == tright {
74✔
4632
                return BooleanType, nil
13✔
4633
        }
13✔
4634

4635
        _, ok := coerceTypes(tleft, tright)
48✔
4636
        if !ok {
52✔
4637
                return AnyType, fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, tleft, tright)
4✔
4638
        }
4✔
4639

4640
        if tleft == AnyType {
47✔
4641
                err = bexp.left.requiresType(tright, cols, params, implicitTable)
3✔
4642
                if err != nil {
3✔
4643
                        return AnyType, err
×
4644
                }
×
4645
        }
4646

4647
        if tright == AnyType {
84✔
4648
                err = bexp.right.requiresType(tleft, cols, params, implicitTable)
40✔
4649
                if err != nil {
41✔
4650
                        return AnyType, err
1✔
4651
                }
1✔
4652
        }
4653
        return BooleanType, nil
43✔
4654
}
4655

4656
func coerceTypes(t1, t2 SQLValueType) (SQLValueType, bool) {
48✔
4657
        switch {
48✔
4658
        case t1 == t2:
×
4659
                return t1, true
×
4660
        case t1 == AnyType:
3✔
4661
                return t2, true
3✔
4662
        case t2 == AnyType:
40✔
4663
                return t1, true
40✔
4664
        case (t1 == IntegerType && t2 == Float64Type) ||
4665
                (t1 == Float64Type && t2 == IntegerType):
1✔
4666
                return Float64Type, true
1✔
4667
        }
4668
        return "", false
4✔
4669
}
4670

4671
func (bexp *CmpBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
41✔
4672
        if t != BooleanType {
42✔
4673
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4674
        }
1✔
4675

4676
        _, err := bexp.inferType(cols, params, implicitTable)
40✔
4677
        return err
40✔
4678
}
4679

4680
func (bexp *CmpBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
14,324✔
4681
        rlexp, err := bexp.left.substitute(params)
14,324✔
4682
        if err != nil {
14,324✔
4683
                return nil, err
×
4684
        }
×
4685

4686
        rrexp, err := bexp.right.substitute(params)
14,324✔
4687
        if err != nil {
14,325✔
4688
                return nil, err
1✔
4689
        }
1✔
4690

4691
        bexp.left = rlexp
14,323✔
4692
        bexp.right = rrexp
14,323✔
4693

14,323✔
4694
        return bexp, nil
14,323✔
4695
}
4696

4697
func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
13,924✔
4698
        vl, err := bexp.left.reduce(tx, row, implicitTable)
13,924✔
4699
        if err != nil {
13,926✔
4700
                return nil, err
2✔
4701
        }
2✔
4702

4703
        vr, err := bexp.right.reduce(tx, row, implicitTable)
13,922✔
4704
        if err != nil {
13,924✔
4705
                return nil, err
2✔
4706
        }
2✔
4707

4708
        r, err := vl.Compare(vr)
13,920✔
4709
        if err != nil {
13,925✔
4710
                return nil, err
5✔
4711
        }
5✔
4712

4713
        return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil
13,915✔
4714
}
4715

4716
func (bexp *CmpBoolExp) selectors() []Selector {
12✔
4717
        return append(bexp.left.selectors(), bexp.right.selectors()...)
12✔
4718
}
12✔
4719

4720
func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
282✔
4721
        return &CmpBoolExp{
282✔
4722
                op:    bexp.op,
282✔
4723
                left:  bexp.left.reduceSelectors(row, implicitTable),
282✔
4724
                right: bexp.right.reduceSelectors(row, implicitTable),
282✔
4725
        }
282✔
4726
}
282✔
4727

4728
func (bexp *CmpBoolExp) isConstant() bool {
2✔
4729
        return bexp.left.isConstant() && bexp.right.isConstant()
2✔
4730
}
2✔
4731

4732
func (bexp *CmpBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
607✔
4733
        matchingFunc := func(_, right ValueExp) (*ColSelector, ValueExp, bool) {
1,423✔
4734
                s, isSel := bexp.left.(*ColSelector)
816✔
4735
                if isSel && s.col != revCol && bexp.right.isConstant() {
1,214✔
4736
                        return s, right, true
398✔
4737
                }
398✔
4738
                return nil, nil, false
418✔
4739
        }
4740

4741
        sel, c, ok := matchingFunc(bexp.left, bexp.right)
607✔
4742
        if !ok {
816✔
4743
                sel, c, ok = matchingFunc(bexp.right, bexp.left)
209✔
4744
        }
209✔
4745

4746
        if !ok {
816✔
4747
                return nil
209✔
4748
        }
209✔
4749

4750
        aggFn, t, col := sel.resolve(table.name)
398✔
4751
        if aggFn != "" || t != asTable {
412✔
4752
                return nil
14✔
4753
        }
14✔
4754

4755
        column, err := table.GetColumnByName(col)
384✔
4756
        if err != nil {
385✔
4757
                return err
1✔
4758
        }
1✔
4759

4760
        val, err := c.substitute(params)
383✔
4761
        if errors.Is(err, ErrMissingParameter) {
442✔
4762
                // TODO: not supported when parameters are not provided during query resolution
59✔
4763
                return nil
59✔
4764
        }
59✔
4765
        if err != nil {
324✔
4766
                return err
×
4767
        }
×
4768

4769
        rval, err := val.reduce(nil, nil, table.name)
324✔
4770
        if err != nil {
325✔
4771
                return err
1✔
4772
        }
1✔
4773

4774
        return updateRangeFor(column.id, rval, bexp.op, rangesByColID)
323✔
4775
}
4776

4777
func (bexp *CmpBoolExp) String() string {
20✔
4778
        opStr := CmpOperatorToString(bexp.op)
20✔
4779
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), opStr, bexp.right.String())
20✔
4780
}
20✔
4781

4782
type TimestampFieldType string
4783

4784
const (
4785
        TimestampFieldTypeYear   TimestampFieldType = "YEAR"
4786
        TimestampFieldTypeMonth  TimestampFieldType = "MONTH"
4787
        TimestampFieldTypeDay    TimestampFieldType = "DAY"
4788
        TimestampFieldTypeHour   TimestampFieldType = "HOUR"
4789
        TimestampFieldTypeMinute TimestampFieldType = "MINUTE"
4790
        TimestampFieldTypeSecond TimestampFieldType = "SECOND"
4791
)
4792

4793
type ExtractFromTimestampExp struct {
4794
        Field TimestampFieldType
4795
        Exp   ValueExp
4796
}
4797

4798
func (te *ExtractFromTimestampExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
4799
        inferredType, err := te.Exp.inferType(cols, params, implicitTable)
3✔
4800
        if err != nil {
3✔
NEW
4801
                return "", err
×
NEW
4802
        }
×
4803

4804
        if inferredType != TimestampType &&
3✔
4805
                inferredType != VarcharType &&
3✔
4806
                inferredType != AnyType {
3✔
NEW
4807
                return "", fmt.Errorf("timestamp expression must be of type %v or %v, but was: %v", TimestampType, VarcharType, inferredType)
×
NEW
4808
        }
×
4809
        return IntegerType, nil
3✔
4810
}
4811

4812
func (te *ExtractFromTimestampExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
4813
        if t != IntegerType && t != Float64Type {
4✔
NEW
4814
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
×
NEW
4815
        }
×
4816
        return te.Exp.requiresType(TimestampType, cols, params, implicitTable)
4✔
4817
}
4818

4819
func (te *ExtractFromTimestampExp) substitute(params map[string]interface{}) (ValueExp, error) {
18✔
4820
        exp, err := te.Exp.substitute(params)
18✔
4821
        if err != nil {
18✔
NEW
4822
                return nil, err
×
NEW
4823
        }
×
4824
        return &ExtractFromTimestampExp{
18✔
4825
                Field: te.Field,
18✔
4826
                Exp:   exp,
18✔
4827
        }, nil
18✔
4828
}
4829

4830
func (te *ExtractFromTimestampExp) selectors() []Selector {
12✔
4831
        return te.Exp.selectors()
12✔
4832
}
12✔
4833

4834
func (te *ExtractFromTimestampExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
18✔
4835
        v, err := te.Exp.reduce(tx, row, implicitTable)
18✔
4836
        if err != nil {
18✔
NEW
4837
                return nil, err
×
NEW
4838
        }
×
4839

4840
        if v.IsNull() {
18✔
NEW
4841
                return NewNull(IntegerType), nil
×
NEW
4842
        }
×
4843

4844
        if t := v.Type(); t != TimestampType && t != VarcharType {
18✔
NEW
4845
                return nil, fmt.Errorf("%w: expected type %v but found type %v", ErrInvalidTypes, TimestampType, t)
×
NEW
4846
        }
×
4847

4848
        if v.Type() == VarcharType {
22✔
4849
                converterFunc, err := getConverter(VarcharType, TimestampType)
4✔
4850
                if err != nil {
4✔
NEW
4851
                        return nil, err
×
NEW
4852
                }
×
4853
                casted, err := converterFunc(v)
4✔
4854
                if err != nil {
4✔
NEW
4855
                        return nil, err
×
NEW
4856
                }
×
4857
                v = casted
4✔
4858
        }
4859

4860
        t, _ := v.RawValue().(time.Time)
18✔
4861

18✔
4862
        year, month, day := t.Date()
18✔
4863

18✔
4864
        switch te.Field {
18✔
4865
        case TimestampFieldTypeYear:
3✔
4866
                return NewInteger(int64(year)), nil
3✔
4867
        case TimestampFieldTypeMonth:
3✔
4868
                return NewInteger(int64(month)), nil
3✔
4869
        case TimestampFieldTypeDay:
3✔
4870
                return NewInteger(int64(day)), nil
3✔
4871
        case TimestampFieldTypeHour:
3✔
4872
                return NewInteger(int64(t.Hour())), nil
3✔
4873
        case TimestampFieldTypeMinute:
3✔
4874
                return NewInteger(int64(t.Minute())), nil
3✔
4875
        case TimestampFieldTypeSecond:
3✔
4876
                return NewInteger(int64(t.Second())), nil
3✔
4877
        }
NEW
4878
        return nil, fmt.Errorf("unknown timestamp field type: %s", te.Field)
×
4879
}
4880

NEW
4881
func (te *ExtractFromTimestampExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
NEW
4882
        return &ExtractFromTimestampExp{
×
NEW
4883
                Field: te.Field,
×
NEW
4884
                Exp:   te.Exp.reduceSelectors(row, implicitTable),
×
NEW
4885
        }
×
NEW
4886
}
×
4887

4888
func (te *ExtractFromTimestampExp) isConstant() bool {
1✔
4889
        return false
1✔
4890
}
1✔
4891

NEW
4892
func (te *ExtractFromTimestampExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
NEW
4893
        return nil
×
NEW
4894
}
×
4895

4896
func (te *ExtractFromTimestampExp) String() string {
6✔
4897
        return fmt.Sprintf("EXTRACT(%s FROM %s)", te.Field, te.Exp)
6✔
4898
}
6✔
4899

4900
func updateRangeFor(colID uint32, val TypedValue, cmp CmpOperator, rangesByColID map[uint32]*typedValueRange) error {
323✔
4901
        currRange, ranged := rangesByColID[colID]
323✔
4902
        var newRange *typedValueRange
323✔
4903

323✔
4904
        switch cmp {
323✔
4905
        case EQ:
250✔
4906
                {
500✔
4907
                        newRange = &typedValueRange{
250✔
4908
                                lRange: &typedValueSemiRange{
250✔
4909
                                        val:       val,
250✔
4910
                                        inclusive: true,
250✔
4911
                                },
250✔
4912
                                hRange: &typedValueSemiRange{
250✔
4913
                                        val:       val,
250✔
4914
                                        inclusive: true,
250✔
4915
                                },
250✔
4916
                        }
250✔
4917
                }
250✔
4918
        case LT:
13✔
4919
                {
26✔
4920
                        newRange = &typedValueRange{
13✔
4921
                                hRange: &typedValueSemiRange{
13✔
4922
                                        val: val,
13✔
4923
                                },
13✔
4924
                        }
13✔
4925
                }
13✔
4926
        case LE:
12✔
4927
                {
24✔
4928
                        newRange = &typedValueRange{
12✔
4929
                                hRange: &typedValueSemiRange{
12✔
4930
                                        val:       val,
12✔
4931
                                        inclusive: true,
12✔
4932
                                },
12✔
4933
                        }
12✔
4934
                }
12✔
4935
        case GT:
18✔
4936
                {
36✔
4937
                        newRange = &typedValueRange{
18✔
4938
                                lRange: &typedValueSemiRange{
18✔
4939
                                        val: val,
18✔
4940
                                },
18✔
4941
                        }
18✔
4942
                }
18✔
4943
        case GE:
18✔
4944
                {
36✔
4945
                        newRange = &typedValueRange{
18✔
4946
                                lRange: &typedValueSemiRange{
18✔
4947
                                        val:       val,
18✔
4948
                                        inclusive: true,
18✔
4949
                                },
18✔
4950
                        }
18✔
4951
                }
18✔
4952
        case NE:
12✔
4953
                {
24✔
4954
                        return nil
12✔
4955
                }
12✔
4956
        }
4957

4958
        if !ranged {
617✔
4959
                rangesByColID[colID] = newRange
306✔
4960
                return nil
306✔
4961
        }
306✔
4962

4963
        return currRange.refineWith(newRange)
5✔
4964
}
4965

4966
func cmpSatisfiesOp(cmp int, op CmpOperator) bool {
13,915✔
4967
        switch {
13,915✔
4968
        case cmp == 0:
1,166✔
4969
                {
2,332✔
4970
                        return op == EQ || op == LE || op == GE
1,166✔
4971
                }
1,166✔
4972
        case cmp < 0:
6,491✔
4973
                {
12,982✔
4974
                        return op == NE || op == LT || op == LE
6,491✔
4975
                }
6,491✔
4976
        case cmp > 0:
6,258✔
4977
                {
12,516✔
4978
                        return op == NE || op == GT || op == GE
6,258✔
4979
                }
6,258✔
4980
        }
4981
        return false
×
4982
}
4983

4984
type BinBoolExp struct {
4985
        op          LogicOperator
4986
        left, right ValueExp
4987
}
4988

4989
func NewBinBoolExp(op LogicOperator, lrexp, rrexp ValueExp) *BinBoolExp {
18✔
4990
        bexp := &BinBoolExp{
18✔
4991
                op: op,
18✔
4992
        }
18✔
4993

18✔
4994
        bexp.left = lrexp
18✔
4995
        bexp.right = rrexp
18✔
4996

18✔
4997
        return bexp
18✔
4998
}
18✔
4999

5000
func (bexp *BinBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
20✔
5001
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
20✔
5002
        if err != nil {
20✔
5003
                return AnyType, err
×
5004
        }
×
5005

5006
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
20✔
5007
        if err != nil {
22✔
5008
                return AnyType, err
2✔
5009
        }
2✔
5010

5011
        return BooleanType, nil
18✔
5012
}
5013

5014
func (bexp *BinBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
22✔
5015
        if t != BooleanType {
25✔
5016
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
3✔
5017
        }
3✔
5018

5019
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
5020
        if err != nil {
20✔
5021
                return err
1✔
5022
        }
1✔
5023

5024
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
18✔
5025
        if err != nil {
18✔
5026
                return err
×
5027
        }
×
5028

5029
        return nil
18✔
5030
}
5031

5032
func (bexp *BinBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
576✔
5033
        rlexp, err := bexp.left.substitute(params)
576✔
5034
        if err != nil {
576✔
5035
                return nil, err
×
5036
        }
×
5037

5038
        rrexp, err := bexp.right.substitute(params)
576✔
5039
        if err != nil {
576✔
5040
                return nil, err
×
5041
        }
×
5042

5043
        bexp.left = rlexp
576✔
5044
        bexp.right = rrexp
576✔
5045

576✔
5046
        return bexp, nil
576✔
5047
}
5048

5049
func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
541✔
5050
        vl, err := bexp.left.reduce(tx, row, implicitTable)
541✔
5051
        if err != nil {
542✔
5052
                return nil, err
1✔
5053
        }
1✔
5054

5055
        bl, isBool := vl.(*Bool)
540✔
5056
        if !isBool {
540✔
5057
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
5058
        }
×
5059

5060
        // short-circuit evaluation
5061
        if (bl.val && bexp.op == Or) || (!bl.val && bexp.op == And) {
716✔
5062
                return &Bool{val: bl.val}, nil
176✔
5063
        }
176✔
5064

5065
        vr, err := bexp.right.reduce(tx, row, implicitTable)
364✔
5066
        if err != nil {
365✔
5067
                return nil, err
1✔
5068
        }
1✔
5069

5070
        br, isBool := vr.(*Bool)
363✔
5071
        if !isBool {
363✔
5072
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
5073
        }
×
5074

5075
        switch bexp.op {
363✔
5076
        case And:
340✔
5077
                {
680✔
5078
                        return &Bool{val: bl.val && br.val}, nil
340✔
5079
                }
340✔
5080
        case Or:
23✔
5081
                {
46✔
5082
                        return &Bool{val: bl.val || br.val}, nil
23✔
5083
                }
23✔
5084
        }
5085

5086
        return nil, ErrUnexpected
×
5087
}
5088

5089
func (bexp *BinBoolExp) selectors() []Selector {
2✔
5090
        return append(bexp.left.selectors(), bexp.right.selectors()...)
2✔
5091
}
2✔
5092

5093
func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
15✔
5094
        return &BinBoolExp{
15✔
5095
                op:    bexp.op,
15✔
5096
                left:  bexp.left.reduceSelectors(row, implicitTable),
15✔
5097
                right: bexp.right.reduceSelectors(row, implicitTable),
15✔
5098
        }
15✔
5099
}
15✔
5100

5101
func (bexp *BinBoolExp) isConstant() bool {
1✔
5102
        return bexp.left.isConstant() && bexp.right.isConstant()
1✔
5103
}
1✔
5104

5105
func (bexp *BinBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
153✔
5106
        if bexp.op == And {
292✔
5107
                err := bexp.left.selectorRanges(table, asTable, params, rangesByColID)
139✔
5108
                if err != nil {
139✔
5109
                        return err
×
5110
                }
×
5111

5112
                return bexp.right.selectorRanges(table, asTable, params, rangesByColID)
139✔
5113
        }
5114

5115
        lRanges := make(map[uint32]*typedValueRange)
14✔
5116
        rRanges := make(map[uint32]*typedValueRange)
14✔
5117

14✔
5118
        err := bexp.left.selectorRanges(table, asTable, params, lRanges)
14✔
5119
        if err != nil {
14✔
5120
                return err
×
5121
        }
×
5122

5123
        err = bexp.right.selectorRanges(table, asTable, params, rRanges)
14✔
5124
        if err != nil {
14✔
5125
                return err
×
5126
        }
×
5127

5128
        for colID, lr := range lRanges {
21✔
5129
                rr, ok := rRanges[colID]
7✔
5130
                if !ok {
9✔
5131
                        continue
2✔
5132
                }
5133

5134
                err = lr.extendWith(rr)
5✔
5135
                if err != nil {
5✔
5136
                        return err
×
5137
                }
×
5138

5139
                rangesByColID[colID] = lr
5✔
5140
        }
5141

5142
        return nil
14✔
5143
}
5144

5145
func (bexp *BinBoolExp) String() string {
31✔
5146
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), LogicOperatorToString(bexp.op), bexp.right.String())
31✔
5147
}
31✔
5148

5149
type ExistsBoolExp struct {
5150
        q DataSource
5151
}
5152

5153
func (bexp *ExistsBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5154
        return AnyType, fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5155
}
1✔
5156

5157
func (bexp *ExistsBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5158
        return fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5159
}
1✔
5160

5161
func (bexp *ExistsBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5162
        return bexp, nil
1✔
5163
}
1✔
5164

5165
func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5166
        return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported)
2✔
5167
}
2✔
5168

5169
func (bexp *ExistsBoolExp) selectors() []Selector {
1✔
5170
        return nil
1✔
5171
}
1✔
5172

5173
func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5174
        return bexp
1✔
5175
}
1✔
5176

5177
func (bexp *ExistsBoolExp) isConstant() bool {
2✔
5178
        return false
2✔
5179
}
2✔
5180

5181
func (bexp *ExistsBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5182
        return nil
1✔
5183
}
1✔
5184

5185
func (bexp *ExistsBoolExp) String() string {
×
5186
        return ""
×
5187
}
×
5188

5189
type InSubQueryExp struct {
5190
        val   ValueExp
5191
        notIn bool
5192
        q     *SelectStmt
5193
}
5194

5195
func (bexp *InSubQueryExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5196
        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5197
}
1✔
5198

5199
func (bexp *InSubQueryExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5200
        return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5201
}
1✔
5202

5203
func (bexp *InSubQueryExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5204
        return bexp, nil
1✔
5205
}
1✔
5206

5207
func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5208
        return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
2✔
5209
}
2✔
5210

5211
func (bexp *InSubQueryExp) selectors() []Selector {
1✔
5212
        return bexp.val.selectors()
1✔
5213
}
1✔
5214

5215
func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5216
        return bexp
1✔
5217
}
1✔
5218

5219
func (bexp *InSubQueryExp) isConstant() bool {
1✔
5220
        return false
1✔
5221
}
1✔
5222

5223
func (bexp *InSubQueryExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5224
        return nil
1✔
5225
}
1✔
5226

5227
func (bexp *InSubQueryExp) String() string {
×
5228
        return ""
×
5229
}
×
5230

5231
// TODO: once InSubQueryExp is supported, this struct may become obsolete by creating a ListDataSource struct
5232
type InListExp struct {
5233
        val    ValueExp
5234
        notIn  bool
5235
        values []ValueExp
5236
}
5237

5238
func (bexp *InListExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
6✔
5239
        t, err := bexp.val.inferType(cols, params, implicitTable)
6✔
5240
        if err != nil {
8✔
5241
                return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
2✔
5242
        }
2✔
5243

5244
        for _, v := range bexp.values {
12✔
5245
                err = v.requiresType(t, cols, params, implicitTable)
8✔
5246
                if err != nil {
9✔
5247
                        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
5248
                }
1✔
5249
        }
5250

5251
        return BooleanType, nil
3✔
5252
}
5253

5254
func (bexp *InListExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
5255
        _, err := bexp.inferType(cols, params, implicitTable)
2✔
5256
        if err != nil {
3✔
5257
                return err
1✔
5258
        }
1✔
5259

5260
        if t != BooleanType {
1✔
5261
                return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrInvalidTypes)
×
5262
        }
×
5263

5264
        return nil
1✔
5265
}
5266

5267
func (bexp *InListExp) substitute(params map[string]interface{}) (ValueExp, error) {
115✔
5268
        val, err := bexp.val.substitute(params)
115✔
5269
        if err != nil {
115✔
5270
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5271
        }
×
5272

5273
        values := make([]ValueExp, len(bexp.values))
115✔
5274

115✔
5275
        for i, val := range bexp.values {
245✔
5276
                values[i], err = val.substitute(params)
130✔
5277
                if err != nil {
130✔
5278
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5279
                }
×
5280
        }
5281

5282
        return &InListExp{
115✔
5283
                val:    val,
115✔
5284
                notIn:  bexp.notIn,
115✔
5285
                values: values,
115✔
5286
        }, nil
115✔
5287
}
5288

5289
func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
115✔
5290
        rval, err := bexp.val.reduce(tx, row, implicitTable)
115✔
5291
        if err != nil {
116✔
5292
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5293
        }
1✔
5294

5295
        var found bool
114✔
5296

114✔
5297
        for _, v := range bexp.values {
241✔
5298
                rv, err := v.reduce(tx, row, implicitTable)
127✔
5299
                if err != nil {
128✔
5300
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5301
                }
1✔
5302

5303
                r, err := rval.Compare(rv)
126✔
5304
                if err != nil {
127✔
5305
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5306
                }
1✔
5307

5308
                if r == 0 {
140✔
5309
                        // TODO: short-circuit evaluation may be preferred when upfront static type inference is in place
15✔
5310
                        found = found || true
15✔
5311
                }
15✔
5312
        }
5313

5314
        return &Bool{val: found != bexp.notIn}, nil
112✔
5315
}
5316

5317
func (bexp *InListExp) selectors() []Selector {
1✔
5318
        selectors := make([]Selector, 0, len(bexp.values))
1✔
5319
        for _, v := range bexp.values {
4✔
5320
                selectors = append(selectors, v.selectors()...)
3✔
5321
        }
3✔
5322
        return append(bexp.val.selectors(), selectors...)
1✔
5323
}
5324

5325
func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
5326
        values := make([]ValueExp, len(bexp.values))
10✔
5327

10✔
5328
        for i, val := range bexp.values {
20✔
5329
                values[i] = val.reduceSelectors(row, implicitTable)
10✔
5330
        }
10✔
5331

5332
        return &InListExp{
10✔
5333
                val:    bexp.val.reduceSelectors(row, implicitTable),
10✔
5334
                values: values,
10✔
5335
        }
10✔
5336
}
5337

5338
func (bexp *InListExp) isConstant() bool {
1✔
5339
        return false
1✔
5340
}
1✔
5341

5342
func (bexp *InListExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
21✔
5343
        // TODO: may be determiined by smallest and bigggest value in the list
21✔
5344
        return nil
21✔
5345
}
21✔
5346

5347
func (bexp *InListExp) String() string {
1✔
5348
        values := make([]string, len(bexp.values))
1✔
5349
        for i, exp := range bexp.values {
5✔
5350
                values[i] = exp.String()
4✔
5351
        }
4✔
5352
        return fmt.Sprintf("%s IN (%s)", bexp.val.String(), strings.Join(values, ","))
1✔
5353
}
5354

5355
type FnDataSourceStmt struct {
5356
        fnCall *FnCall
5357
        as     string
5358
}
5359

5360
func (stmt *FnDataSourceStmt) readOnly() bool {
1✔
5361
        return true
1✔
5362
}
1✔
5363

5364
func (stmt *FnDataSourceStmt) requiredPrivileges() []SQLPrivilege {
1✔
5365
        return nil
1✔
5366
}
1✔
5367

5368
func (stmt *FnDataSourceStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
5369
        return tx, nil
×
5370
}
×
5371

5372
func (stmt *FnDataSourceStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5373
        return nil
1✔
5374
}
1✔
5375

5376
func (stmt *FnDataSourceStmt) Alias() string {
24✔
5377
        if stmt.as != "" {
26✔
5378
                return stmt.as
2✔
5379
        }
2✔
5380

5381
        switch strings.ToUpper(stmt.fnCall.fn) {
22✔
5382
        case DatabasesFnCall:
3✔
5383
                {
6✔
5384
                        return "databases"
3✔
5385
                }
3✔
5386
        case TablesFnCall:
5✔
5387
                {
10✔
5388
                        return "tables"
5✔
5389
                }
5✔
5390
        case TableFnCall:
×
5391
                {
×
5392
                        return "table"
×
5393
                }
×
5394
        case UsersFnCall:
7✔
5395
                {
14✔
5396
                        return "users"
7✔
5397
                }
7✔
5398
        case ColumnsFnCall:
3✔
5399
                {
6✔
5400
                        return "columns"
3✔
5401
                }
3✔
5402
        case IndexesFnCall:
2✔
5403
                {
4✔
5404
                        return "indexes"
2✔
5405
                }
2✔
5406
        case GrantsFnCall:
2✔
5407
                return "grants"
2✔
5408
        }
5409

5410
        // not reachable
5411
        return ""
×
5412
}
5413

5414
func (stmt *FnDataSourceStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (rowReader RowReader, err error) {
25✔
5415
        if stmt.fnCall == nil {
25✔
5416
                return nil, fmt.Errorf("%w: function is unspecified", ErrIllegalArguments)
×
5417
        }
×
5418

5419
        switch strings.ToUpper(stmt.fnCall.fn) {
25✔
5420
        case DatabasesFnCall:
5✔
5421
                {
10✔
5422
                        return stmt.resolveListDatabases(ctx, tx, params, scanSpecs)
5✔
5423
                }
5✔
5424
        case TablesFnCall:
5✔
5425
                {
10✔
5426
                        return stmt.resolveListTables(ctx, tx, params, scanSpecs)
5✔
5427
                }
5✔
5428
        case TableFnCall:
×
5429
                {
×
5430
                        return stmt.resolveShowTable(ctx, tx, params, scanSpecs)
×
5431
                }
×
5432
        case UsersFnCall:
7✔
5433
                {
14✔
5434
                        return stmt.resolveListUsers(ctx, tx, params, scanSpecs)
7✔
5435
                }
7✔
5436
        case ColumnsFnCall:
3✔
5437
                {
6✔
5438
                        return stmt.resolveListColumns(ctx, tx, params, scanSpecs)
3✔
5439
                }
3✔
5440
        case IndexesFnCall:
3✔
5441
                {
6✔
5442
                        return stmt.resolveListIndexes(ctx, tx, params, scanSpecs)
3✔
5443
                }
3✔
5444
        case GrantsFnCall:
2✔
5445
                {
4✔
5446
                        return stmt.resolveListGrants(ctx, tx, params, scanSpecs)
2✔
5447
                }
2✔
5448
        }
5449

5450
        return nil, fmt.Errorf("%w (%s)", ErrFunctionDoesNotExist, stmt.fnCall.fn)
×
5451
}
5452

5453
func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5454
        if len(stmt.fnCall.params) > 0 {
5✔
5455
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, DatabasesFnCall, len(stmt.fnCall.params))
×
5456
        }
×
5457

5458
        cols := make([]ColDescriptor, 1)
5✔
5459
        cols[0] = ColDescriptor{
5✔
5460
                Column: "name",
5✔
5461
                Type:   VarcharType,
5✔
5462
        }
5✔
5463

5✔
5464
        var dbs []string
5✔
5465

5✔
5466
        if tx.engine.multidbHandler == nil {
6✔
5467
                return nil, ErrUnspecifiedMultiDBHandler
1✔
5468
        } else {
5✔
5469
                dbs, err = tx.engine.multidbHandler.ListDatabases(ctx)
4✔
5470
                if err != nil {
4✔
5471
                        return nil, err
×
5472
                }
×
5473
        }
5474

5475
        values := make([][]ValueExp, len(dbs))
4✔
5476

4✔
5477
        for i, db := range dbs {
12✔
5478
                values[i] = []ValueExp{&Varchar{val: db}}
8✔
5479
        }
8✔
5480

5481
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
4✔
5482
}
5483

5484
func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5485
        if len(stmt.fnCall.params) > 0 {
5✔
5486
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, TablesFnCall, len(stmt.fnCall.params))
×
5487
        }
×
5488

5489
        cols := make([]ColDescriptor, 1)
5✔
5490
        cols[0] = ColDescriptor{
5✔
5491
                Column: "name",
5✔
5492
                Type:   VarcharType,
5✔
5493
        }
5✔
5494

5✔
5495
        tables := tx.catalog.GetTables()
5✔
5496

5✔
5497
        values := make([][]ValueExp, len(tables))
5✔
5498

5✔
5499
        for i, t := range tables {
14✔
5500
                values[i] = []ValueExp{&Varchar{val: t.name}}
9✔
5501
        }
9✔
5502

5503
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5504
}
5505

5506
func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
×
5507
        cols := []ColDescriptor{
×
5508
                {
×
5509
                        Column: "column_name",
×
5510
                        Type:   VarcharType,
×
5511
                },
×
5512
                {
×
5513
                        Column: "type_name",
×
5514
                        Type:   VarcharType,
×
5515
                },
×
5516
                {
×
5517
                        Column: "is_nullable",
×
5518
                        Type:   BooleanType,
×
5519
                },
×
5520
                {
×
5521
                        Column: "is_indexed",
×
5522
                        Type:   VarcharType,
×
5523
                },
×
5524
                {
×
5525
                        Column: "is_auto_increment",
×
5526
                        Type:   BooleanType,
×
5527
                },
×
5528
                {
×
5529
                        Column: "is_unique",
×
5530
                        Type:   BooleanType,
×
5531
                },
×
5532
        }
×
5533

×
5534
        tableName, _ := stmt.fnCall.params[0].reduce(tx, nil, "")
×
5535
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
×
5536
        if err != nil {
×
5537
                return nil, err
×
5538
        }
×
5539

5540
        values := make([][]ValueExp, len(table.cols))
×
5541

×
5542
        for i, c := range table.cols {
×
5543
                index := "NO"
×
5544

×
5545
                indexed, err := table.IsIndexed(c.Name())
×
5546
                if err != nil {
×
5547
                        return nil, err
×
5548
                }
×
5549
                if indexed {
×
5550
                        index = "YES"
×
5551
                }
×
5552

5553
                if table.PrimaryIndex().IncludesCol(c.ID()) {
×
5554
                        index = "PRIMARY KEY"
×
5555
                }
×
5556

5557
                var unique bool
×
5558
                for _, index := range table.GetIndexesByColID(c.ID()) {
×
5559
                        if index.IsUnique() && len(index.Cols()) == 1 {
×
5560
                                unique = true
×
5561
                                break
×
5562
                        }
5563
                }
5564

5565
                var maxLen string
×
5566

×
5567
                if c.MaxLen() > 0 && (c.Type() == VarcharType || c.Type() == BLOBType) {
×
5568
                        maxLen = fmt.Sprintf("(%d)", c.MaxLen())
×
5569
                }
×
5570

5571
                values[i] = []ValueExp{
×
5572
                        &Varchar{val: c.colName},
×
5573
                        &Varchar{val: c.Type() + maxLen},
×
5574
                        &Bool{val: c.IsNullable()},
×
5575
                        &Varchar{val: index},
×
5576
                        &Bool{val: c.IsAutoIncremental()},
×
5577
                        &Bool{val: unique},
×
5578
                }
×
5579
        }
5580

5581
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
×
5582
}
5583

5584
func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
7✔
5585
        if len(stmt.fnCall.params) > 0 {
7✔
5586
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, UsersFnCall, len(stmt.fnCall.params))
×
5587
        }
×
5588

5589
        cols := []ColDescriptor{
7✔
5590
                {
7✔
5591
                        Column: "name",
7✔
5592
                        Type:   VarcharType,
7✔
5593
                },
7✔
5594
                {
7✔
5595
                        Column: "permission",
7✔
5596
                        Type:   VarcharType,
7✔
5597
                },
7✔
5598
        }
7✔
5599

7✔
5600
        users, err := tx.ListUsers(ctx)
7✔
5601
        if err != nil {
7✔
5602
                return nil, err
×
5603
        }
×
5604

5605
        values := make([][]ValueExp, len(users))
7✔
5606
        for i, user := range users {
23✔
5607
                perm := user.Permission()
16✔
5608

16✔
5609
                values[i] = []ValueExp{
16✔
5610
                        &Varchar{val: user.Username()},
16✔
5611
                        &Varchar{val: perm},
16✔
5612
                }
16✔
5613
        }
16✔
5614
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
7✔
5615
}
5616

5617
func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5618
        if len(stmt.fnCall.params) != 1 {
3✔
5619
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, ColumnsFnCall)
×
5620
        }
×
5621

5622
        cols := []ColDescriptor{
3✔
5623
                {
3✔
5624
                        Column: "table",
3✔
5625
                        Type:   VarcharType,
3✔
5626
                },
3✔
5627
                {
3✔
5628
                        Column: "name",
3✔
5629
                        Type:   VarcharType,
3✔
5630
                },
3✔
5631
                {
3✔
5632
                        Column: "type",
3✔
5633
                        Type:   VarcharType,
3✔
5634
                },
3✔
5635
                {
3✔
5636
                        Column: "max_length",
3✔
5637
                        Type:   IntegerType,
3✔
5638
                },
3✔
5639
                {
3✔
5640
                        Column: "nullable",
3✔
5641
                        Type:   BooleanType,
3✔
5642
                },
3✔
5643
                {
3✔
5644
                        Column: "auto_increment",
3✔
5645
                        Type:   BooleanType,
3✔
5646
                },
3✔
5647
                {
3✔
5648
                        Column: "indexed",
3✔
5649
                        Type:   BooleanType,
3✔
5650
                },
3✔
5651
                {
3✔
5652
                        Column: "primary",
3✔
5653
                        Type:   BooleanType,
3✔
5654
                },
3✔
5655
                {
3✔
5656
                        Column: "unique",
3✔
5657
                        Type:   BooleanType,
3✔
5658
                },
3✔
5659
        }
3✔
5660

3✔
5661
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5662
        if err != nil {
3✔
5663
                return nil, err
×
5664
        }
×
5665

5666
        tableName, err := val.reduce(tx, nil, "")
3✔
5667
        if err != nil {
3✔
5668
                return nil, err
×
5669
        }
×
5670

5671
        if tableName.Type() != VarcharType {
3✔
5672
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5673
        }
×
5674

5675
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5676
        if err != nil {
3✔
5677
                return nil, err
×
5678
        }
×
5679

5680
        values := make([][]ValueExp, len(table.cols))
3✔
5681

3✔
5682
        for i, c := range table.cols {
11✔
5683
                indexed, err := table.IsIndexed(c.Name())
8✔
5684
                if err != nil {
8✔
5685
                        return nil, err
×
5686
                }
×
5687

5688
                var unique bool
8✔
5689
                for _, index := range table.indexesByColID[c.id] {
16✔
5690
                        if index.IsUnique() && len(index.Cols()) == 1 {
11✔
5691
                                unique = true
3✔
5692
                                break
3✔
5693
                        }
5694
                }
5695

5696
                values[i] = []ValueExp{
8✔
5697
                        &Varchar{val: table.name},
8✔
5698
                        &Varchar{val: c.colName},
8✔
5699
                        &Varchar{val: c.colType},
8✔
5700
                        &Integer{val: int64(c.MaxLen())},
8✔
5701
                        &Bool{val: c.IsNullable()},
8✔
5702
                        &Bool{val: c.autoIncrement},
8✔
5703
                        &Bool{val: indexed},
8✔
5704
                        &Bool{val: table.PrimaryIndex().IncludesCol(c.ID())},
8✔
5705
                        &Bool{val: unique},
8✔
5706
                }
8✔
5707
        }
5708

5709
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5710
}
5711

5712
func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5713
        if len(stmt.fnCall.params) != 1 {
3✔
5714
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, IndexesFnCall)
×
5715
        }
×
5716

5717
        cols := []ColDescriptor{
3✔
5718
                {
3✔
5719
                        Column: "table",
3✔
5720
                        Type:   VarcharType,
3✔
5721
                },
3✔
5722
                {
3✔
5723
                        Column: "name",
3✔
5724
                        Type:   VarcharType,
3✔
5725
                },
3✔
5726
                {
3✔
5727
                        Column: "unique",
3✔
5728
                        Type:   BooleanType,
3✔
5729
                },
3✔
5730
                {
3✔
5731
                        Column: "primary",
3✔
5732
                        Type:   BooleanType,
3✔
5733
                },
3✔
5734
        }
3✔
5735

3✔
5736
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5737
        if err != nil {
3✔
5738
                return nil, err
×
5739
        }
×
5740

5741
        tableName, err := val.reduce(tx, nil, "")
3✔
5742
        if err != nil {
3✔
5743
                return nil, err
×
5744
        }
×
5745

5746
        if tableName.Type() != VarcharType {
3✔
5747
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5748
        }
×
5749

5750
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5751
        if err != nil {
3✔
5752
                return nil, err
×
5753
        }
×
5754

5755
        values := make([][]ValueExp, len(table.indexes))
3✔
5756

3✔
5757
        for i, index := range table.indexes {
10✔
5758
                values[i] = []ValueExp{
7✔
5759
                        &Varchar{val: table.name},
7✔
5760
                        &Varchar{val: index.Name()},
7✔
5761
                        &Bool{val: index.unique},
7✔
5762
                        &Bool{val: index.IsPrimary()},
7✔
5763
                }
7✔
5764
        }
7✔
5765

5766
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5767
}
5768

5769
func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
2✔
5770
        if len(stmt.fnCall.params) > 1 {
2✔
5771
                return nil, fmt.Errorf("%w: function '%s' expect at most one parameter of type %s", ErrIllegalArguments, GrantsFnCall, VarcharType)
×
5772
        }
×
5773

5774
        var username string
2✔
5775
        if len(stmt.fnCall.params) == 1 {
3✔
5776
                val, err := stmt.fnCall.params[0].substitute(params)
1✔
5777
                if err != nil {
1✔
5778
                        return nil, err
×
5779
                }
×
5780

5781
                userVal, err := val.reduce(tx, nil, "")
1✔
5782
                if err != nil {
1✔
5783
                        return nil, err
×
5784
                }
×
5785

5786
                if userVal.Type() != VarcharType {
1✔
5787
                        return nil, fmt.Errorf("%w: expected '%s' for username but type '%s' given instead", ErrIllegalArguments, VarcharType, userVal.Type())
×
5788
                }
×
5789
                username, _ = userVal.RawValue().(string)
1✔
5790
        }
5791

5792
        cols := []ColDescriptor{
2✔
5793
                {
2✔
5794
                        Column: "user",
2✔
5795
                        Type:   VarcharType,
2✔
5796
                },
2✔
5797
                {
2✔
5798
                        Column: "privilege",
2✔
5799
                        Type:   VarcharType,
2✔
5800
                },
2✔
5801
        }
2✔
5802

2✔
5803
        var err error
2✔
5804
        var users []User
2✔
5805

2✔
5806
        if tx.engine.multidbHandler == nil {
2✔
5807
                return nil, ErrUnspecifiedMultiDBHandler
×
5808
        } else {
2✔
5809
                users, err = tx.engine.multidbHandler.ListUsers(ctx)
2✔
5810
                if err != nil {
2✔
5811
                        return nil, err
×
5812
                }
×
5813
        }
5814

5815
        values := make([][]ValueExp, 0, len(users))
2✔
5816

2✔
5817
        for _, user := range users {
4✔
5818
                if username == "" || user.Username() == username {
4✔
5819
                        for _, p := range user.SQLPrivileges() {
6✔
5820
                                values = append(values, []ValueExp{
4✔
5821
                                        &Varchar{val: user.Username()},
4✔
5822
                                        &Varchar{val: string(p)},
4✔
5823
                                })
4✔
5824
                        }
4✔
5825
                }
5826
        }
5827

5828
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
2✔
5829
}
5830

5831
// DropTableStmt represents a DROP TABLE statement, i.e. a request to delete a
// table by name.
type DropTableStmt struct {
	table string // name of the table to drop
}
5835

5836
func NewDropTableStmt(table string) *DropTableStmt {
6✔
5837
        return &DropTableStmt{table: table}
6✔
5838
}
6✔
5839

5840
func (stmt *DropTableStmt) readOnly() bool {
1✔
5841
        return false
1✔
5842
}
1✔
5843

5844
func (stmt *DropTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
5845
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5846
}
1✔
5847

5848
func (stmt *DropTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5849
        return nil
1✔
5850
}
1✔
5851

5852
/*
5853
Exec executes the delete table statement.
5854
It the table exists, if not it does nothing.
5855
If the table exists, it deletes all the indexes and the table itself.
5856
Note that this is a soft delete of the index and table key,
5857
the data is not deleted, but the metadata is updated.
5858
*/
5859
func (stmt *DropTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
7✔
5860
        if !tx.catalog.ExistTable(stmt.table) {
8✔
5861
                return nil, ErrTableDoesNotExist
1✔
5862
        }
1✔
5863

5864
        table, err := tx.catalog.GetTableByName(stmt.table)
6✔
5865
        if err != nil {
6✔
5866
                return nil, err
×
5867
        }
×
5868

5869
        // delete table
5870
        mappedKey := MapKey(
6✔
5871
                tx.sqlPrefix(),
6✔
5872
                catalogTablePrefix,
6✔
5873
                EncodeID(DatabaseID),
6✔
5874
                EncodeID(table.id),
6✔
5875
        )
6✔
5876
        err = tx.delete(ctx, mappedKey)
6✔
5877
        if err != nil {
6✔
5878
                return nil, err
×
5879
        }
×
5880

5881
        // delete columns
5882
        cols := table.ColumnsByID()
6✔
5883
        for _, col := range cols {
26✔
5884
                mappedKey := MapKey(
20✔
5885
                        tx.sqlPrefix(),
20✔
5886
                        catalogColumnPrefix,
20✔
5887
                        EncodeID(DatabaseID),
20✔
5888
                        EncodeID(col.table.id),
20✔
5889
                        EncodeID(col.id),
20✔
5890
                        []byte(col.colType),
20✔
5891
                )
20✔
5892
                err = tx.delete(ctx, mappedKey)
20✔
5893
                if err != nil {
20✔
5894
                        return nil, err
×
5895
                }
×
5896
        }
5897

5898
        // delete checks
5899
        for name := range table.checkConstraints {
6✔
5900
                key := MapKey(
×
5901
                        tx.sqlPrefix(),
×
5902
                        catalogCheckPrefix,
×
5903
                        EncodeID(DatabaseID),
×
5904
                        EncodeID(table.id),
×
5905
                        []byte(name),
×
5906
                )
×
5907

×
5908
                if err := tx.delete(ctx, key); err != nil {
×
5909
                        return nil, err
×
5910
                }
×
5911
        }
5912

5913
        // delete indexes
5914
        for _, index := range table.indexes {
13✔
5915
                mappedKey := MapKey(
7✔
5916
                        tx.sqlPrefix(),
7✔
5917
                        catalogIndexPrefix,
7✔
5918
                        EncodeID(DatabaseID),
7✔
5919
                        EncodeID(table.id),
7✔
5920
                        EncodeID(index.id),
7✔
5921
                )
7✔
5922
                err = tx.delete(ctx, mappedKey)
7✔
5923
                if err != nil {
7✔
5924
                        return nil, err
×
5925
                }
×
5926

5927
                indexKey := MapKey(
7✔
5928
                        tx.sqlPrefix(),
7✔
5929
                        MappedPrefix,
7✔
5930
                        EncodeID(table.id),
7✔
5931
                        EncodeID(index.id),
7✔
5932
                )
7✔
5933
                err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
14✔
5934
                        return sqlTx.engine.store.DeleteIndex(indexKey)
7✔
5935
                })
7✔
5936
                if err != nil {
7✔
5937
                        return nil, err
×
5938
                }
×
5939
        }
5940

5941
        err = tx.catalog.deleteTable(table)
6✔
5942
        if err != nil {
6✔
5943
                return nil, err
×
5944
        }
×
5945

5946
        tx.mutatedCatalog = true
6✔
5947

6✔
5948
        return tx, nil
6✔
5949
}
5950

5951
// DropIndexStmt represents a statement to delete a table.
5952
type DropIndexStmt struct {
5953
        table string
5954
        cols  []string
5955
}
5956

5957
func NewDropIndexStmt(table string, cols []string) *DropIndexStmt {
4✔
5958
        return &DropIndexStmt{table: table, cols: cols}
4✔
5959
}
4✔
5960

5961
func (stmt *DropIndexStmt) readOnly() bool {
1✔
5962
        return false
1✔
5963
}
1✔
5964

5965
func (stmt *DropIndexStmt) requiredPrivileges() []SQLPrivilege {
1✔
5966
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5967
}
1✔
5968

5969
func (stmt *DropIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5970
        return nil
1✔
5971
}
1✔
5972

5973
/*
5974
Exec executes the delete index statement.
5975
If the index exists, it deletes it. Note that this is a soft delete of the index
5976
the data is not deleted, but the metadata is updated.
5977
*/
5978
func (stmt *DropIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
5979
        if !tx.catalog.ExistTable(stmt.table) {
7✔
5980
                return nil, ErrTableDoesNotExist
1✔
5981
        }
1✔
5982

5983
        table, err := tx.catalog.GetTableByName(stmt.table)
5✔
5984
        if err != nil {
5✔
5985
                return nil, err
×
5986
        }
×
5987

5988
        cols := make([]*Column, len(stmt.cols))
5✔
5989

5✔
5990
        for i, colName := range stmt.cols {
10✔
5991
                col, err := table.GetColumnByName(colName)
5✔
5992
                if err != nil {
5✔
5993
                        return nil, err
×
5994
                }
×
5995

5996
                cols[i] = col
5✔
5997
        }
5998

5999
        index, err := table.GetIndexByName(indexName(table.name, cols))
5✔
6000
        if err != nil {
5✔
6001
                return nil, err
×
6002
        }
×
6003

6004
        // delete index
6005
        mappedKey := MapKey(
5✔
6006
                tx.sqlPrefix(),
5✔
6007
                catalogIndexPrefix,
5✔
6008
                EncodeID(DatabaseID),
5✔
6009
                EncodeID(table.id),
5✔
6010
                EncodeID(index.id),
5✔
6011
        )
5✔
6012
        err = tx.delete(ctx, mappedKey)
5✔
6013
        if err != nil {
5✔
6014
                return nil, err
×
6015
        }
×
6016

6017
        indexKey := MapKey(
5✔
6018
                tx.sqlPrefix(),
5✔
6019
                MappedPrefix,
5✔
6020
                EncodeID(table.id),
5✔
6021
                EncodeID(index.id),
5✔
6022
        )
5✔
6023

5✔
6024
        err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
9✔
6025
                return sqlTx.engine.store.DeleteIndex(indexKey)
4✔
6026
        })
4✔
6027
        if err != nil {
5✔
6028
                return nil, err
×
6029
        }
×
6030

6031
        err = table.deleteIndex(index)
5✔
6032
        if err != nil {
6✔
6033
                return nil, err
1✔
6034
        }
1✔
6035

6036
        tx.mutatedCatalog = true
4✔
6037

4✔
6038
        return tx, nil
4✔
6039
}
6040

6041
type SQLPrivilege string
6042

6043
const (
6044
        SQLPrivilegeSelect SQLPrivilege = "SELECT"
6045
        SQLPrivilegeCreate SQLPrivilege = "CREATE"
6046
        SQLPrivilegeInsert SQLPrivilege = "INSERT"
6047
        SQLPrivilegeUpdate SQLPrivilege = "UPDATE"
6048
        SQLPrivilegeDelete SQLPrivilege = "DELETE"
6049
        SQLPrivilegeDrop   SQLPrivilege = "DROP"
6050
        SQLPrivilegeAlter  SQLPrivilege = "ALTER"
6051
)
6052

6053
var allPrivileges = []SQLPrivilege{
6054
        SQLPrivilegeSelect,
6055
        SQLPrivilegeCreate,
6056
        SQLPrivilegeInsert,
6057
        SQLPrivilegeUpdate,
6058
        SQLPrivilegeDelete,
6059
        SQLPrivilegeDrop,
6060
        SQLPrivilegeAlter,
6061
}
6062

6063
func DefaultSQLPrivilegesForPermission(p Permission) []SQLPrivilege {
295✔
6064
        switch p {
295✔
6065
        case PermissionSysAdmin, PermissionAdmin, PermissionReadWrite:
284✔
6066
                return allPrivileges
284✔
6067
        case PermissionReadOnly:
11✔
6068
                return []SQLPrivilege{SQLPrivilegeSelect}
11✔
6069
        }
6070
        return nil
×
6071
}
6072

6073
type AlterPrivilegesStmt struct {
6074
        database   string
6075
        user       string
6076
        privileges []SQLPrivilege
6077
        isGrant    bool
6078
}
6079

6080
func (stmt *AlterPrivilegesStmt) readOnly() bool {
2✔
6081
        return false
2✔
6082
}
2✔
6083

6084
func (stmt *AlterPrivilegesStmt) requiredPrivileges() []SQLPrivilege {
2✔
6085
        return nil
2✔
6086
}
2✔
6087

6088
func (stmt *AlterPrivilegesStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2✔
6089
        if tx.IsExplicitCloseRequired() {
3✔
6090
                return nil, fmt.Errorf("%w: user privileges modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
6091
        }
1✔
6092

6093
        if tx.engine.multidbHandler == nil {
1✔
6094
                return nil, ErrUnspecifiedMultiDBHandler
×
6095
        }
×
6096

6097
        var err error
1✔
6098
        if stmt.isGrant {
1✔
6099
                err = tx.engine.multidbHandler.GrantSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
×
6100
        } else {
1✔
6101
                err = tx.engine.multidbHandler.RevokeSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
1✔
6102
        }
1✔
6103
        return nil, err
1✔
6104
}
6105

6106
func (stmt *AlterPrivilegesStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
6107
        return nil
1✔
6108
}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc