• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 18243888765

04 Oct 2025 11:45AM UTC coverage: 89.221% (-0.02%) from 89.241%
18243888765

Pull #2073

gh-ci

ostafen
chore(embedded/sql): Implement EXTRACT FROM TIMESTAMP expressions

Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>
Pull Request #2073: chore(embedded/sql): Implement EXTRACT FROM TIMESTAMP expressions

357 of 402 new or added lines in 3 files covered. (88.81%)

1 existing line in 1 file now uncovered.

37909 of 42489 relevant lines covered (89.22%)

149222.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.97
/embedded/sql/stmt.go
1
/*
2
Copyright 2025 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "context"
22
        "encoding/binary"
23
        "encoding/hex"
24
        "errors"
25
        "fmt"
26
        "math"
27
        "regexp"
28
        "strconv"
29
        "strings"
30
        "time"
31

32
        "github.com/codenotary/immudb/embedded/store"
33
        "github.com/google/uuid"
34
)
35

36
// Key prefixes used to lay out SQL catalog metadata and table data in the
// underlying key-value store. The trailing comments sketch each key/value
// encoding; "{1}" stands for EncodeID(DatabaseID), which is still written into
// catalog keys for backwards compatibility.
const (
	catalogPrefix          = "CTL."
	catalogTablePrefix     = "CTL.TABLE."     // (key=CTL.TABLE.{1}{tableID}, value={tableNAME})
	catalogColumnPrefix    = "CTL.COLUMN."    // (key=CTL.COLUMN.{1}{tableID}{colID}{colTYPE}, value={(auto_incremental | nullable){maxLen}{colNAME}})
	catalogIndexPrefix     = "CTL.INDEX."     // (key=CTL.INDEX.{1}{tableID}{indexID}, value={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)})
	catalogCheckPrefix     = "CTL.CHECK."     // (key=CTL.CHECK.{1}{tableID}{checkID}, value={nameLen-1}{name}{expText})
	catalogPrivilegePrefix = "CTL.PRIVILEGE." // privilege catalog entries; key/value layout is defined where privileges are persisted (TODO: document it here)

	RowPrefix    = "R." // (key=R.{1}{tableID}{0}({null}({pkVal}{padding}{pkValLen})?)+, value={count (colID valLen val)+})
	MappedPrefix = "M." // (key=M.{tableID}{indexID}({null}({val}{padding}{valLen})?)*({pkVal}{padding}{pkValLen})+, value={count (colID valLen val)+})
)
47

48
// Fixed identifiers used when encoding catalog and index keys.
const (
	// DatabaseID is deprecated but still written into catalog keys so the
	// existing key layout stays backwards compatible.
	DatabaseID uint32 = 1
	// PKIndexID is, per its name, the id reserved for the primary-key index
	// (TODO: confirm against the catalog code).
	PKIndexID uint32 = 0
)

// Bit flags stored in the first byte of a persisted column entry.
const (
	// nullableFlag is set by persistColumn for NOT NULL columns, despite its
	// name.
	nullableFlag byte = 1 << 0
	// autoIncrementFlag is set for AUTO_INCREMENT columns.
	autoIncrementFlag byte = 1 << 1
)
57

58
// Names of system columns managed by the engine.
const (
	revCol        = "_rev"
	txMetadataCol = "_tx_metadata"
)

// reservedColumns is the set of column names isReservedCol treats as reserved.
var reservedColumns = map[string]struct{}{
	revCol:        {},
	txMetadataCol: {},
}

// isReservedCol reports whether col is one of the reserved system column names.
// The comparison is exact (case-sensitive).
func isReservedCol(col string) bool {
	if _, found := reservedColumns[col]; found {
		return true
	}
	return false
}
72

73
// SQLValueType is the name of a SQL value type (e.g. "INTEGER").
type SQLValueType = string

// SQL value types understood by the engine.
const (
	IntegerType   SQLValueType = "INTEGER"
	BooleanType   SQLValueType = "BOOLEAN"
	VarcharType   SQLValueType = "VARCHAR"
	UUIDType      SQLValueType = "UUID"
	BLOBType      SQLValueType = "BLOB"
	Float64Type   SQLValueType = "FLOAT"
	TimestampType SQLValueType = "TIMESTAMP"
	AnyType       SQLValueType = "ANY"
	JSONType      SQLValueType = "JSON"
)

// IsNumericType reports whether t is one of the numeric types, INTEGER or FLOAT.
func IsNumericType(t SQLValueType) bool {
	switch t {
	case IntegerType, Float64Type:
		return true
	default:
		return false
	}
}
90

91
// Permission names a user permission level.
type Permission = string

// Permission levels.
const (
	PermissionReadOnly  Permission = "READ"
	PermissionReadWrite Permission = "READWRITE"
	PermissionAdmin     Permission = "ADMIN"
	PermissionSysAdmin  Permission = "SYSADMIN"
)

// PermissionFromCode maps a numeric permission code to its Permission:
// 1 is READ, 2 is READWRITE, 254 is ADMIN, and any other code is SYSADMIN.
func PermissionFromCode(code uint32) Permission {
	switch code {
	case 1:
		return PermissionReadOnly
	case 2:
		return PermissionReadWrite
	case 254:
		return PermissionAdmin
	default:
		// NOTE(review): every unrecognised code (including 0) falls through to
		// the most privileged level; confirm callers only pass known codes.
		return PermissionSysAdmin
	}
}
117

118
// AggregateFn names a SQL aggregate function.
type AggregateFn = string

// Supported aggregate functions.
const (
	COUNT AggregateFn = "COUNT"
	SUM   AggregateFn = "SUM"
	MAX   AggregateFn = "MAX"
	MIN   AggregateFn = "MIN"
	AVG   AggregateFn = "AVG"
)
127

128
// CmpOperator identifies a comparison operator.
type CmpOperator = int

// Comparison operators.
const (
	EQ CmpOperator = iota // =
	NE                    // !=
	LT                    // <
	LE                    // <=
	GT                    // >
	GE                    // >=
)

// CmpOperatorToString returns the SQL symbol for op, or "" when op is not a
// known comparison operator.
func CmpOperatorToString(op CmpOperator) string {
	symbols := [...]string{
		EQ: "=",
		NE: "!=",
		LT: "<",
		LE: "<=",
		GT: ">",
		GE: ">=",
	}
	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
156

157
// LogicOperator identifies a boolean connective.
type LogicOperator = int

// Logical operators.
const (
	And LogicOperator = iota
	Or
)

// LogicOperatorToString returns "AND" for And and "OR" for any other value.
func LogicOperatorToString(op LogicOperator) string {
	if op != And {
		return "OR"
	}
	return "AND"
}
170

171
// NumOperator identifies an arithmetic operator.
type NumOperator = int

// Arithmetic operators.
const (
	ADDOP  NumOperator = iota // +
	SUBSOP                    // -
	DIVOP                     // /
	MULTOP                    // *
	MODOP                     // %
)

// NumOperatorString returns the SQL symbol for op, or "" when op is not a
// known arithmetic operator.
func NumOperatorString(op NumOperator) string {
	symbols := [...]string{
		ADDOP:  "+",
		SUBSOP: "-",
		DIVOP:  "/",
		MULTOP: "*",
		MODOP:  "%",
	}
	if op >= 0 && op < len(symbols) {
		return symbols[op]
	}
	return ""
}
196

197
// JoinType selects the kind of join between two data sources.
type JoinType = int

// Supported join kinds.
const (
	InnerJoin JoinType = iota
	LeftJoin
	RightJoin
)
204

205
type SQLStmt interface {
206
        readOnly() bool
207
        requiredPrivileges() []SQLPrivilege
208
        execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error)
209
        inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error
210
}
211

212
type BeginTransactionStmt struct {
213
}
214

215
func (stmt *BeginTransactionStmt) readOnly() bool {
9✔
216
        return true
9✔
217
}
9✔
218

219
func (stmt *BeginTransactionStmt) requiredPrivileges() []SQLPrivilege {
9✔
220
        return nil
9✔
221
}
9✔
222

223
func (stmt *BeginTransactionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
224
        return nil
3✔
225
}
3✔
226

227
func (stmt *BeginTransactionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
61✔
228
        if tx.IsExplicitCloseRequired() {
62✔
229
                return nil, ErrNestedTxNotSupported
1✔
230
        }
1✔
231

232
        err := tx.RequireExplicitClose()
60✔
233
        if err == nil {
119✔
234
                // current tx can be reused as no changes were already made
59✔
235
                return tx, nil
59✔
236
        }
59✔
237

238
        // commit current transaction and start a fresh one
239

240
        err = tx.Commit(ctx)
1✔
241
        if err != nil {
1✔
242
                return nil, err
×
243
        }
×
244

245
        return tx.engine.NewTx(ctx, tx.opts.WithExplicitClose(true))
1✔
246
}
247

248
type CommitStmt struct {
249
}
250

251
func (stmt *CommitStmt) readOnly() bool {
113✔
252
        return true
113✔
253
}
113✔
254

255
func (stmt *CommitStmt) requiredPrivileges() []SQLPrivilege {
113✔
256
        return nil
113✔
257
}
113✔
258

259
func (stmt *CommitStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
260
        return nil
1✔
261
}
1✔
262

263
func (stmt *CommitStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
159✔
264
        if !tx.IsExplicitCloseRequired() {
160✔
265
                return nil, ErrNoOngoingTx
1✔
266
        }
1✔
267

268
        return nil, tx.Commit(ctx)
158✔
269
}
270

271
type RollbackStmt struct {
272
}
273

274
func (stmt *RollbackStmt) readOnly() bool {
1✔
275
        return true
1✔
276
}
1✔
277

278
func (stmt *RollbackStmt) requiredPrivileges() []SQLPrivilege {
1✔
279
        return nil
1✔
280
}
1✔
281

282
func (stmt *RollbackStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
283
        return nil
1✔
284
}
1✔
285

286
func (stmt *RollbackStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
287
        if !tx.IsExplicitCloseRequired() {
5✔
288
                return nil, ErrNoOngoingTx
1✔
289
        }
1✔
290

291
        return nil, tx.Cancel()
3✔
292
}
293

294
type CreateDatabaseStmt struct {
295
        DB          string
296
        ifNotExists bool
297
}
298

299
func (stmt *CreateDatabaseStmt) readOnly() bool {
14✔
300
        return false
14✔
301
}
14✔
302

303
func (stmt *CreateDatabaseStmt) requiredPrivileges() []SQLPrivilege {
14✔
304
        return []SQLPrivilege{SQLPrivilegeCreate}
14✔
305
}
14✔
306

307
func (stmt *CreateDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
308
        return nil
4✔
309
}
4✔
310

311
func (stmt *CreateDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
312
        if tx.IsExplicitCloseRequired() {
16✔
313
                return nil, fmt.Errorf("%w: database creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
314
        }
1✔
315

316
        if tx.engine.multidbHandler == nil {
16✔
317
                return nil, ErrUnspecifiedMultiDBHandler
2✔
318
        }
2✔
319

320
        return nil, tx.engine.multidbHandler.CreateDatabase(ctx, stmt.DB, stmt.ifNotExists)
12✔
321
}
322

323
type UseDatabaseStmt struct {
324
        DB string
325
}
326

327
func (stmt *UseDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
328
        return nil
1✔
329
}
1✔
330

331
func (stmt *UseDatabaseStmt) readOnly() bool {
9✔
332
        return true
9✔
333
}
9✔
334

335
func (stmt *UseDatabaseStmt) requiredPrivileges() []SQLPrivilege {
9✔
336
        return nil
9✔
337
}
9✔
338

339
func (stmt *UseDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
340
        if tx.IsExplicitCloseRequired() {
10✔
341
                return nil, fmt.Errorf("%w: database selection can NOT be executed within a transaction block", ErrNonTransactionalStmt)
1✔
342
        }
1✔
343

344
        if tx.engine.multidbHandler == nil {
9✔
345
                return nil, ErrUnspecifiedMultiDBHandler
1✔
346
        }
1✔
347

348
        return tx, tx.engine.multidbHandler.UseDatabase(ctx, stmt.DB)
7✔
349
}
350

351
type UseSnapshotStmt struct {
352
        period period
353
}
354

355
func (stmt *UseSnapshotStmt) readOnly() bool {
1✔
356
        return true
1✔
357
}
1✔
358

359
func (stmt *UseSnapshotStmt) requiredPrivileges() []SQLPrivilege {
1✔
360
        return nil
1✔
361
}
1✔
362

363
func (stmt *UseSnapshotStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
364
        return nil
1✔
365
}
1✔
366

367
func (stmt *UseSnapshotStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
368
        return nil, ErrNoSupported
1✔
369
}
1✔
370

371
type CreateUserStmt struct {
372
        username   string
373
        password   string
374
        permission Permission
375
}
376

377
func (stmt *CreateUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
378
        return nil
1✔
379
}
1✔
380

381
func (stmt *CreateUserStmt) readOnly() bool {
5✔
382
        return false
5✔
383
}
5✔
384

385
func (stmt *CreateUserStmt) requiredPrivileges() []SQLPrivilege {
5✔
386
        return []SQLPrivilege{SQLPrivilegeCreate}
5✔
387
}
5✔
388

389
func (stmt *CreateUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
390
        if tx.IsExplicitCloseRequired() {
7✔
391
                return nil, fmt.Errorf("%w: user creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
392
        }
1✔
393

394
        if tx.engine.multidbHandler == nil {
6✔
395
                return nil, ErrUnspecifiedMultiDBHandler
1✔
396
        }
1✔
397

398
        return nil, tx.engine.multidbHandler.CreateUser(ctx, stmt.username, stmt.password, stmt.permission)
4✔
399
}
400

401
type AlterUserStmt struct {
402
        username   string
403
        password   string
404
        permission Permission
405
}
406

407
func (stmt *AlterUserStmt) readOnly() bool {
4✔
408
        return false
4✔
409
}
4✔
410

411
func (stmt *AlterUserStmt) requiredPrivileges() []SQLPrivilege {
4✔
412
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
413
}
4✔
414

415
func (stmt *AlterUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
416
        return nil
1✔
417
}
1✔
418

419
func (stmt *AlterUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
5✔
420
        if tx.IsExplicitCloseRequired() {
6✔
421
                return nil, fmt.Errorf("%w: user modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
422
        }
1✔
423

424
        if tx.engine.multidbHandler == nil {
5✔
425
                return nil, ErrUnspecifiedMultiDBHandler
1✔
426
        }
1✔
427

428
        return nil, tx.engine.multidbHandler.AlterUser(ctx, stmt.username, stmt.password, stmt.permission)
3✔
429
}
430

431
type DropUserStmt struct {
432
        username string
433
}
434

435
func (stmt *DropUserStmt) readOnly() bool {
3✔
436
        return false
3✔
437
}
3✔
438

439
func (stmt *DropUserStmt) requiredPrivileges() []SQLPrivilege {
3✔
440
        return []SQLPrivilege{SQLPrivilegeDrop}
3✔
441
}
3✔
442

443
func (stmt *DropUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
444
        return nil
1✔
445
}
1✔
446

447
func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
448
        if tx.IsExplicitCloseRequired() {
4✔
449
                return nil, fmt.Errorf("%w: user deletion can not be done within a transaction", ErrNonTransactionalStmt)
1✔
450
        }
1✔
451

452
        if tx.engine.multidbHandler == nil {
3✔
453
                return nil, ErrUnspecifiedMultiDBHandler
1✔
454
        }
1✔
455

456
        return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
1✔
457
}
458

459
type TableElem interface{}
460

461
type CreateTableStmt struct {
462
        table       string
463
        ifNotExists bool
464
        colsSpec    []*ColSpec
465
        checks      []CheckConstraint
466
        pkColNames  PrimaryKeyConstraint
467
}
468

469
func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
38✔
470
        return &CreateTableStmt{table: table, ifNotExists: ifNotExists, colsSpec: colsSpec, pkColNames: pkColNames}
38✔
471
}
38✔
472

473
func (stmt *CreateTableStmt) readOnly() bool {
66✔
474
        return false
66✔
475
}
66✔
476

477
func (stmt *CreateTableStmt) requiredPrivileges() []SQLPrivilege {
64✔
478
        return []SQLPrivilege{SQLPrivilegeCreate}
64✔
479
}
64✔
480

481
func (stmt *CreateTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
482
        return nil
4✔
483
}
4✔
484

485
func zeroRow(tableName string, cols []*ColSpec) *Row {
244✔
486
        r := Row{
244✔
487
                ValuesByPosition: make([]TypedValue, len(cols)),
244✔
488
                ValuesBySelector: make(map[string]TypedValue, len(cols)),
244✔
489
        }
244✔
490

244✔
491
        for i, col := range cols {
1,024✔
492
                v := zeroForType(col.colType)
780✔
493

780✔
494
                r.ValuesByPosition[i] = v
780✔
495
                r.ValuesBySelector[EncodeSelector("", tableName, col.colName)] = v
780✔
496
        }
780✔
497
        return &r
244✔
498
}
499

500
func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
235✔
501
        if err := stmt.validatePrimaryKey(); err != nil {
238✔
502
                return nil, err
3✔
503
        }
3✔
504

505
        if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
233✔
506
                return tx, nil
1✔
507
        }
1✔
508

509
        colSpecs := make(map[uint32]*ColSpec, len(stmt.colsSpec))
231✔
510
        for i, cs := range stmt.colsSpec {
951✔
511
                colSpecs[uint32(i)+1] = cs
720✔
512
        }
720✔
513

514
        row := zeroRow(stmt.table, stmt.colsSpec)
231✔
515
        for _, check := range stmt.checks {
240✔
516
                value, err := check.exp.reduce(tx, row, stmt.table)
9✔
517
                if err != nil {
11✔
518
                        return nil, err
2✔
519
                }
2✔
520

521
                if value.Type() != BooleanType {
7✔
522
                        return nil, ErrInvalidCheckConstraint
×
523
                }
×
524
        }
525

526
        nextUnnamedCheck := 0
229✔
527
        checks := make(map[string]CheckConstraint)
229✔
528
        for id, check := range stmt.checks {
236✔
529
                name := fmt.Sprintf("%s_check%d", stmt.table, nextUnnamedCheck+1)
7✔
530
                if check.name != "" {
9✔
531
                        name = check.name
2✔
532
                } else {
7✔
533
                        nextUnnamedCheck++
5✔
534
                }
5✔
535
                check.id = uint32(id)
7✔
536
                check.name = name
7✔
537
                checks[name] = check
7✔
538
        }
539

540
        table, err := tx.catalog.newTable(stmt.table, colSpecs, checks, uint32(len(colSpecs)))
229✔
541
        if err != nil {
235✔
542
                return nil, err
6✔
543
        }
6✔
544

545
        createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.primaryKeyCols()}
223✔
546
        _, err = createIndexStmt.execAt(ctx, tx, params)
223✔
547
        if err != nil {
228✔
548
                return nil, err
5✔
549
        }
5✔
550

551
        for _, col := range table.cols {
919✔
552
                if col.autoIncrement {
778✔
553
                        if len(table.primaryIndex.cols) > 1 || col.id != table.primaryIndex.cols[0].id {
78✔
554
                                return nil, ErrLimitedAutoIncrement
1✔
555
                        }
1✔
556
                }
557

558
                err := persistColumn(tx, col)
700✔
559
                if err != nil {
700✔
560
                        return nil, err
×
561
                }
×
562
        }
563

564
        for _, check := range checks {
224✔
565
                if err := persistCheck(tx, table, &check); err != nil {
7✔
566
                        return nil, err
×
567
                }
×
568
        }
569

570
        mappedKey := MapKey(tx.sqlPrefix(), catalogTablePrefix, EncodeID(DatabaseID), EncodeID(table.id))
217✔
571

217✔
572
        err = tx.set(mappedKey, nil, []byte(table.name))
217✔
573
        if err != nil {
217✔
574
                return nil, err
×
575
        }
×
576

577
        tx.mutatedCatalog = true
217✔
578

217✔
579
        return tx, nil
217✔
580
}
581

582
func (stmt *CreateTableStmt) validatePrimaryKey() error {
235✔
583
        n := 0
235✔
584
        for _, spec := range stmt.colsSpec {
962✔
585
                if spec.primaryKey {
732✔
586
                        n++
5✔
587
                }
5✔
588
        }
589

590
        if len(stmt.pkColNames) > 0 {
466✔
591
                n++
231✔
592
        }
231✔
593

594
        switch n {
235✔
595
        case 0:
1✔
596
                return ErrNoPrimaryKey
1✔
597
        case 1:
232✔
598
                return nil
232✔
599
        }
600
        return fmt.Errorf("\"%s\": %w", stmt.table, ErrMultiplePrimaryKeys)
2✔
601
}
602

603
func (stmt *CreateTableStmt) primaryKeyCols() []string {
223✔
604
        if len(stmt.pkColNames) > 0 {
444✔
605
                return stmt.pkColNames
221✔
606
        }
221✔
607

608
        for _, spec := range stmt.colsSpec {
4✔
609
                if spec.primaryKey {
4✔
610
                        return []string{spec.colName}
2✔
611
                }
2✔
612
        }
613
        return nil
×
614
}
615

616
func persistColumn(tx *SQLTx, col *Column) error {
720✔
617
        //{auto_incremental | nullable}{maxLen}{colNAME})
720✔
618
        v := make([]byte, 1+4+len(col.colName))
720✔
619

720✔
620
        if col.autoIncrement {
796✔
621
                v[0] = v[0] | autoIncrementFlag
76✔
622
        }
76✔
623

624
        if col.notNull {
767✔
625
                v[0] = v[0] | nullableFlag
47✔
626
        }
47✔
627

628
        binary.BigEndian.PutUint32(v[1:], uint32(col.MaxLen()))
720✔
629

720✔
630
        copy(v[5:], []byte(col.Name()))
720✔
631

720✔
632
        mappedKey := MapKey(
720✔
633
                tx.sqlPrefix(),
720✔
634
                catalogColumnPrefix,
720✔
635
                EncodeID(DatabaseID),
720✔
636
                EncodeID(col.table.id),
720✔
637
                EncodeID(col.id),
720✔
638
                []byte(col.colType),
720✔
639
        )
720✔
640

720✔
641
        return tx.set(mappedKey, nil, v)
720✔
642
}
643

644
func persistCheck(tx *SQLTx, table *Table, check *CheckConstraint) error {
7✔
645
        mappedKey := MapKey(
7✔
646
                tx.sqlPrefix(),
7✔
647
                catalogCheckPrefix,
7✔
648
                EncodeID(DatabaseID),
7✔
649
                EncodeID(table.id),
7✔
650
                EncodeID(check.id),
7✔
651
        )
7✔
652

7✔
653
        name := check.name
7✔
654
        expText := check.exp.String()
7✔
655

7✔
656
        val := make([]byte, 2+len(name)+len(expText))
7✔
657

7✔
658
        if len(name) > 256 {
7✔
659
                return fmt.Errorf("constraint name len: %w", ErrMaxLengthExceeded)
×
660
        }
×
661

662
        val[0] = byte(len(name)) - 1
7✔
663

7✔
664
        copy(val[1:], []byte(name))
7✔
665
        copy(val[1+len(name):], []byte(expText))
7✔
666

7✔
667
        return tx.set(mappedKey, nil, val)
7✔
668
}
669

670
type ColSpec struct {
671
        colName       string
672
        colType       SQLValueType
673
        maxLen        int
674
        autoIncrement bool
675
        notNull       bool
676
        primaryKey    bool
677
}
678

679
func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {
188✔
680
        return &ColSpec{
188✔
681
                colName:       name,
188✔
682
                colType:       colType,
188✔
683
                maxLen:        maxLen,
188✔
684
                autoIncrement: autoIncrement,
188✔
685
                notNull:       notNull,
188✔
686
        }
188✔
687
}
188✔
688

689
type CreateIndexStmt struct {
690
        unique      bool
691
        ifNotExists bool
692
        table       string
693
        cols        []string
694
}
695

696
func NewCreateIndexStmt(table string, cols []string, isUnique bool) *CreateIndexStmt {
72✔
697
        return &CreateIndexStmt{unique: isUnique, table: table, cols: cols}
72✔
698
}
72✔
699

700
func (stmt *CreateIndexStmt) readOnly() bool {
7✔
701
        return false
7✔
702
}
7✔
703

704
func (stmt *CreateIndexStmt) requiredPrivileges() []SQLPrivilege {
7✔
705
        return []SQLPrivilege{SQLPrivilegeCreate}
7✔
706
}
7✔
707

708
func (stmt *CreateIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
709
        return nil
1✔
710
}
1✔
711

712
func (stmt *CreateIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
374✔
713
        if len(stmt.cols) < 1 {
375✔
714
                return nil, ErrIllegalArguments
1✔
715
        }
1✔
716

717
        if len(stmt.cols) > MaxNumberOfColumnsInIndex {
374✔
718
                return nil, ErrMaxNumberOfColumnsInIndexExceeded
1✔
719
        }
1✔
720

721
        table, err := tx.catalog.GetTableByName(stmt.table)
372✔
722
        if err != nil {
374✔
723
                return nil, err
2✔
724
        }
2✔
725

726
        colIDs := make([]uint32, len(stmt.cols))
370✔
727

370✔
728
        indexKeyLen := 0
370✔
729

370✔
730
        for i, colName := range stmt.cols {
767✔
731
                col, err := table.GetColumnByName(colName)
397✔
732
                if err != nil {
402✔
733
                        return nil, err
5✔
734
                }
5✔
735

736
                if col.Type() == JSONType {
394✔
737
                        return nil, ErrCannotIndexJson
2✔
738
                }
2✔
739

740
                if variableSizedType(col.colType) && !tx.engine.lazyIndexConstraintValidation && (col.MaxLen() == 0 || col.MaxLen() > MaxKeyLen) {
392✔
741
                        return nil, fmt.Errorf("%w: can not create index using column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
2✔
742
                }
2✔
743

744
                indexKeyLen += col.MaxLen()
388✔
745

388✔
746
                colIDs[i] = col.id
388✔
747
        }
748

749
        if !tx.engine.lazyIndexConstraintValidation && indexKeyLen > MaxKeyLen {
361✔
750
                return nil, fmt.Errorf("%w: can not create index using columns '%v'. Max key length is %d", ErrLimitedKeyType, stmt.cols, MaxKeyLen)
×
751
        }
×
752

753
        if stmt.unique && table.primaryIndex != nil {
381✔
754
                // check table is empty
20✔
755
                pkPrefix := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id))
20✔
756
                _, _, err := tx.getWithPrefix(ctx, pkPrefix, nil)
20✔
757
                if errors.Is(err, store.ErrIndexNotFound) {
20✔
758
                        return nil, ErrTableDoesNotExist
×
759
                }
×
760
                if err == nil {
21✔
761
                        return nil, ErrLimitedIndexCreation
1✔
762
                } else if !errors.Is(err, store.ErrKeyNotFound) {
20✔
763
                        return nil, err
×
764
                }
×
765
        }
766

767
        index, err := table.newIndex(stmt.unique, colIDs)
360✔
768
        if errors.Is(err, ErrIndexAlreadyExists) && stmt.ifNotExists {
362✔
769
                return tx, nil
2✔
770
        }
2✔
771
        if err != nil {
362✔
772
                return nil, err
4✔
773
        }
4✔
774

775
        // v={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)}
776
        // TODO: currently only ASC order is supported
777
        colSpecLen := EncIDLen + 1
354✔
778

354✔
779
        encodedValues := make([]byte, 1+len(index.cols)*colSpecLen)
354✔
780

354✔
781
        if index.IsUnique() {
590✔
782
                encodedValues[0] = 1
236✔
783
        }
236✔
784

785
        for i, col := range index.cols {
735✔
786
                copy(encodedValues[1+i*colSpecLen:], EncodeID(col.id))
381✔
787
        }
381✔
788

789
        mappedKey := MapKey(tx.sqlPrefix(), catalogIndexPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(index.id))
354✔
790

354✔
791
        err = tx.set(mappedKey, nil, encodedValues)
354✔
792
        if err != nil {
354✔
793
                return nil, err
×
794
        }
×
795

796
        tx.mutatedCatalog = true
354✔
797

354✔
798
        return tx, nil
354✔
799
}
800

801
type AddColumnStmt struct {
802
        table   string
803
        colSpec *ColSpec
804
}
805

806
func NewAddColumnStmt(table string, colSpec *ColSpec) *AddColumnStmt {
6✔
807
        return &AddColumnStmt{table: table, colSpec: colSpec}
6✔
808
}
6✔
809

810
func (stmt *AddColumnStmt) readOnly() bool {
4✔
811
        return false
4✔
812
}
4✔
813

814
func (stmt *AddColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
815
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
816
}
4✔
817

818
func (stmt *AddColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
819
        return nil
1✔
820
}
1✔
821

822
func (stmt *AddColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
823
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
824
        if err != nil {
20✔
825
                return nil, err
1✔
826
        }
1✔
827

828
        col, err := table.newColumn(stmt.colSpec)
18✔
829
        if err != nil {
24✔
830
                return nil, err
6✔
831
        }
6✔
832

833
        err = persistColumn(tx, col)
12✔
834
        if err != nil {
12✔
835
                return nil, err
×
836
        }
×
837

838
        tx.mutatedCatalog = true
12✔
839

12✔
840
        return tx, nil
12✔
841
}
842

843
type RenameTableStmt struct {
844
        oldName string
845
        newName string
846
}
847

848
func (stmt *RenameTableStmt) readOnly() bool {
1✔
849
        return false
1✔
850
}
1✔
851

852
func (stmt *RenameTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
853
        return []SQLPrivilege{SQLPrivilegeAlter}
1✔
854
}
1✔
855

856
func (stmt *RenameTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
857
        return nil
1✔
858
}
1✔
859

860
func (stmt *RenameTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
861
        table, err := tx.catalog.renameTable(stmt.oldName, stmt.newName)
6✔
862
        if err != nil {
10✔
863
                return nil, err
4✔
864
        }
4✔
865

866
        // update table name
867
        mappedKey := MapKey(
2✔
868
                tx.sqlPrefix(),
2✔
869
                catalogTablePrefix,
2✔
870
                EncodeID(DatabaseID),
2✔
871
                EncodeID(table.id),
2✔
872
        )
2✔
873
        err = tx.set(mappedKey, nil, []byte(stmt.newName))
2✔
874
        if err != nil {
2✔
875
                return nil, err
×
876
        }
×
877

878
        tx.mutatedCatalog = true
2✔
879

2✔
880
        return tx, nil
2✔
881
}
882

883
type RenameColumnStmt struct {
884
        table   string
885
        oldName string
886
        newName string
887
}
888

889
func NewRenameColumnStmt(table, oldName, newName string) *RenameColumnStmt {
3✔
890
        return &RenameColumnStmt{table: table, oldName: oldName, newName: newName}
3✔
891
}
3✔
892

893
func (stmt *RenameColumnStmt) readOnly() bool {
4✔
894
        return false
4✔
895
}
4✔
896

897
func (stmt *RenameColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
898
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
899
}
4✔
900

901
func (stmt *RenameColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
902
        return nil
1✔
903
}
1✔
904

905
func (stmt *RenameColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
10✔
906
        table, err := tx.catalog.GetTableByName(stmt.table)
10✔
907
        if err != nil {
11✔
908
                return nil, err
1✔
909
        }
1✔
910

911
        col, err := table.renameColumn(stmt.oldName, stmt.newName)
9✔
912
        if err != nil {
12✔
913
                return nil, err
3✔
914
        }
3✔
915

916
        err = persistColumn(tx, col)
6✔
917
        if err != nil {
6✔
918
                return nil, err
×
919
        }
×
920

921
        tx.mutatedCatalog = true
6✔
922

6✔
923
        return tx, nil
6✔
924
}
925

926
type DropColumnStmt struct {
927
        table   string
928
        colName string
929
}
930

931
func NewDropColumnStmt(table, colName string) *DropColumnStmt {
8✔
932
        return &DropColumnStmt{table: table, colName: colName}
8✔
933
}
8✔
934

935
func (stmt *DropColumnStmt) readOnly() bool {
2✔
936
        return false
2✔
937
}
2✔
938

939
func (stmt *DropColumnStmt) requiredPrivileges() []SQLPrivilege {
2✔
940
        return []SQLPrivilege{SQLPrivilegeDrop}
2✔
941
}
2✔
942

943
func (stmt *DropColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
944
        return nil
1✔
945
}
1✔
946

947
func (stmt *DropColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
948
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
949
        if err != nil {
21✔
950
                return nil, err
2✔
951
        }
2✔
952

953
        col, err := table.GetColumnByName(stmt.colName)
17✔
954
        if err != nil {
21✔
955
                return nil, err
4✔
956
        }
4✔
957

958
        err = canDropColumn(tx, table, col)
13✔
959
        if err != nil {
14✔
960
                return nil, err
1✔
961
        }
1✔
962

963
        err = table.deleteColumn(col)
12✔
964
        if err != nil {
16✔
965
                return nil, err
4✔
966
        }
4✔
967

968
        err = persistColumnDeletion(ctx, tx, col)
8✔
969
        if err != nil {
8✔
970
                return nil, err
×
971
        }
×
972

973
        tx.mutatedCatalog = true
8✔
974

8✔
975
        return tx, nil
8✔
976
}
977

978
func canDropColumn(tx *SQLTx, table *Table, col *Column) error {
13✔
979
        colSpecs := make([]*ColSpec, 0, len(table.Cols())-1)
13✔
980
        for _, c := range table.cols {
86✔
981
                if c.id != col.id {
133✔
982
                        colSpecs = append(colSpecs, &ColSpec{colName: c.Name(), colType: c.Type()})
60✔
983
                }
60✔
984
        }
985

986
        row := zeroRow(table.Name(), colSpecs)
13✔
987
        for name, check := range table.checkConstraints {
19✔
988
                _, err := check.exp.reduce(tx, row, table.name)
6✔
989
                if errors.Is(err, ErrColumnDoesNotExist) {
7✔
990
                        return fmt.Errorf("%w %s because %s constraint requires it", ErrCannotDropColumn, col.Name(), name)
1✔
991
                }
1✔
992

993
                if err != nil {
5✔
994
                        return err
×
995
                }
×
996
        }
997
        return nil
12✔
998
}
999

1000
func persistColumnDeletion(ctx context.Context, tx *SQLTx, col *Column) error {
9✔
1001
        mappedKey := MapKey(
9✔
1002
                tx.sqlPrefix(),
9✔
1003
                catalogColumnPrefix,
9✔
1004
                EncodeID(DatabaseID),
9✔
1005
                EncodeID(col.table.id),
9✔
1006
                EncodeID(col.id),
9✔
1007
                []byte(col.colType),
9✔
1008
        )
9✔
1009

9✔
1010
        return tx.delete(ctx, mappedKey)
9✔
1011
}
9✔
1012

1013
type DropConstraintStmt struct {
1014
        table          string
1015
        constraintName string
1016
}
1017

1018
func (stmt *DropConstraintStmt) readOnly() bool {
×
1019
        return false
×
1020
}
×
1021

1022
func (stmt *DropConstraintStmt) requiredPrivileges() []SQLPrivilege {
×
1023
        return []SQLPrivilege{SQLPrivilegeDrop}
×
1024
}
×
1025

1026
func (stmt *DropConstraintStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
1027
        table, err := tx.catalog.GetTableByName(stmt.table)
4✔
1028
        if err != nil {
4✔
1029
                return nil, err
×
1030
        }
×
1031

1032
        id, err := table.deleteCheck(stmt.constraintName)
4✔
1033
        if err != nil {
5✔
1034
                return nil, err
1✔
1035
        }
1✔
1036

1037
        err = persistCheckDeletion(ctx, tx, table.id, id)
3✔
1038

3✔
1039
        tx.mutatedCatalog = true
3✔
1040

3✔
1041
        return tx, err
3✔
1042
}
1043

1044
func persistCheckDeletion(ctx context.Context, tx *SQLTx, tableID uint32, checkId uint32) error {
3✔
1045
        mappedKey := MapKey(
3✔
1046
                tx.sqlPrefix(),
3✔
1047
                catalogCheckPrefix,
3✔
1048
                EncodeID(DatabaseID),
3✔
1049
                EncodeID(tableID),
3✔
1050
                EncodeID(checkId),
3✔
1051
        )
3✔
1052
        return tx.delete(ctx, mappedKey)
3✔
1053
}
3✔
1054

1055
func (stmt *DropConstraintStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
1056
        return nil
×
1057
}
×
1058

1059
type UpsertIntoStmt struct {
1060
        isInsert   bool
1061
        tableRef   *tableRef
1062
        cols       []string
1063
        ds         DataSource
1064
        onConflict *OnConflictDo
1065
}
1066

1067
func (stmt *UpsertIntoStmt) readOnly() bool {
101✔
1068
        return false
101✔
1069
}
101✔
1070

1071
func (stmt *UpsertIntoStmt) requiredPrivileges() []SQLPrivilege {
101✔
1072
        privileges := stmt.privileges()
101✔
1073
        if stmt.ds != nil {
200✔
1074
                privileges = append(privileges, stmt.ds.requiredPrivileges()...)
99✔
1075
        }
99✔
1076
        return privileges
101✔
1077
}
1078

1079
func (stmt *UpsertIntoStmt) privileges() []SQLPrivilege {
101✔
1080
        if stmt.isInsert {
190✔
1081
                return []SQLPrivilege{SQLPrivilegeInsert}
89✔
1082
        }
89✔
1083
        return []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate}
12✔
1084
}
1085

1086
func NewUpsertIntoStmt(table string, cols []string, ds DataSource, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt {
120✔
1087
        return &UpsertIntoStmt{
120✔
1088
                isInsert:   isInsert,
120✔
1089
                tableRef:   NewTableRef(table, ""),
120✔
1090
                cols:       cols,
120✔
1091
                ds:         ds,
120✔
1092
                onConflict: onConflict,
120✔
1093
        }
120✔
1094
}
120✔
1095

1096
type RowSpec struct {
1097
        Values []ValueExp
1098
}
1099

1100
func NewRowSpec(values []ValueExp) *RowSpec {
129✔
1101
        return &RowSpec{
129✔
1102
                Values: values,
129✔
1103
        }
129✔
1104
}
129✔
1105

1106
type OnConflictDo struct{}
1107

1108
func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
11✔
1109
        ds, ok := stmt.ds.(*valuesDataSource)
11✔
1110
        if !ok {
11✔
1111
                return stmt.ds.inferParameters(ctx, tx, params)
×
1112
        }
×
1113

1114
        emptyDescriptors := make(map[string]ColDescriptor)
11✔
1115
        for _, row := range ds.rows {
23✔
1116
                if len(stmt.cols) != len(row.Values) {
13✔
1117
                        return ErrInvalidNumberOfValues
1✔
1118
                }
1✔
1119

1120
                for i, val := range row.Values {
36✔
1121
                        table, err := stmt.tableRef.referencedTable(tx)
25✔
1122
                        if err != nil {
26✔
1123
                                return err
1✔
1124
                        }
1✔
1125

1126
                        col, err := table.GetColumnByName(stmt.cols[i])
24✔
1127
                        if err != nil {
25✔
1128
                                return err
1✔
1129
                        }
1✔
1130

1131
                        err = val.requiresType(col.colType, emptyDescriptors, params, table.name)
23✔
1132
                        if err != nil {
25✔
1133
                                return err
2✔
1134
                        }
2✔
1135
                }
1136
        }
1137
        return nil
6✔
1138
}
1139

1140
func (stmt *UpsertIntoStmt) validate(table *Table) (map[uint32]int, error) {
2,115✔
1141
        selPosByColID := make(map[uint32]int, len(stmt.cols))
2,115✔
1142

2,115✔
1143
        for i, c := range stmt.cols {
9,736✔
1144
                col, err := table.GetColumnByName(c)
7,621✔
1145
                if err != nil {
7,623✔
1146
                        return nil, err
2✔
1147
                }
2✔
1148

1149
                _, duplicated := selPosByColID[col.id]
7,619✔
1150
                if duplicated {
7,620✔
1151
                        return nil, fmt.Errorf("%w (%s)", ErrDuplicatedColumn, col.colName)
1✔
1152
                }
1✔
1153

1154
                selPosByColID[col.id] = i
7,618✔
1155
        }
1156

1157
        return selPosByColID, nil
2,112✔
1158
}
1159

1160
func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2,118✔
1161
        table, err := stmt.tableRef.referencedTable(tx)
2,118✔
1162
        if err != nil {
2,121✔
1163
                return nil, err
3✔
1164
        }
3✔
1165

1166
        selPosByColID, err := stmt.validate(table)
2,115✔
1167
        if err != nil {
2,118✔
1168
                return nil, err
3✔
1169
        }
3✔
1170

1171
        r := &Row{
2,112✔
1172
                ValuesByPosition: make([]TypedValue, len(table.cols)),
2,112✔
1173
                ValuesBySelector: make(map[string]TypedValue),
2,112✔
1174
        }
2,112✔
1175

2,112✔
1176
        reader, err := stmt.ds.Resolve(ctx, tx, params, nil)
2,112✔
1177
        if err != nil {
2,112✔
1178
                return nil, err
×
1179
        }
×
1180
        defer reader.Close()
2,112✔
1181

2,112✔
1182
        for {
6,490✔
1183
                row, err := reader.Read(ctx)
4,378✔
1184
                if errors.Is(err, ErrNoMoreRows) {
6,447✔
1185
                        break
2,069✔
1186
                }
1187
                if err != nil {
2,319✔
1188
                        return nil, err
10✔
1189
                }
10✔
1190

1191
                if len(row.ValuesByPosition) != len(stmt.cols) {
2,301✔
1192
                        return nil, ErrInvalidNumberOfValues
2✔
1193
                }
2✔
1194

1195
                valuesByColID := make(map[uint32]TypedValue)
2,297✔
1196

2,297✔
1197
                var pkMustExist bool
2,297✔
1198

2,297✔
1199
                for colID, col := range table.colsByID {
11,742✔
1200
                        colPos, specified := selPosByColID[colID]
9,445✔
1201
                        if !specified {
10,691✔
1202
                                // TODO: Default values
1,246✔
1203
                                if col.notNull && !col.autoIncrement {
1,247✔
1204
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
1✔
1205
                                }
1✔
1206

1207
                                // inject auto-incremental pk value
1208
                                if stmt.isInsert && col.autoIncrement {
2,303✔
1209
                                        // current implementation assumes only PK can be set as autoincremental
1,058✔
1210
                                        table.maxPK++
1,058✔
1211

1,058✔
1212
                                        pkCol := table.primaryIndex.cols[0]
1,058✔
1213
                                        valuesByColID[pkCol.id] = &Integer{val: table.maxPK}
1,058✔
1214

1,058✔
1215
                                        if _, ok := tx.firstInsertedPKs[table.name]; !ok {
1,924✔
1216
                                                tx.firstInsertedPKs[table.name] = table.maxPK
866✔
1217
                                        }
866✔
1218
                                        tx.lastInsertedPKs[table.name] = table.maxPK
1,058✔
1219
                                }
1220

1221
                                continue
1,245✔
1222
                        }
1223

1224
                        // value was specified
1225
                        cVal := row.ValuesByPosition[colPos]
8,199✔
1226

8,199✔
1227
                        val, err := cVal.substitute(params)
8,199✔
1228
                        if err != nil {
8,199✔
1229
                                return nil, err
×
1230
                        }
×
1231

1232
                        rval, err := val.reduce(tx, nil, table.name)
8,199✔
1233
                        if err != nil {
8,199✔
1234
                                return nil, err
×
1235
                        }
×
1236

1237
                        if rval.IsNull() {
8,297✔
1238
                                if col.notNull || col.autoIncrement {
98✔
1239
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
×
1240
                                }
×
1241

1242
                                continue
98✔
1243
                        }
1244

1245
                        if col.autoIncrement {
8,120✔
1246
                                // validate specified value
19✔
1247
                                nl, isNumber := rval.RawValue().(int64)
19✔
1248
                                if !isNumber {
19✔
1249
                                        return nil, fmt.Errorf("%w (expecting numeric value)", ErrInvalidValue)
×
1250
                                }
×
1251

1252
                                pkMustExist = nl <= table.maxPK
19✔
1253

19✔
1254
                                if _, ok := tx.firstInsertedPKs[table.name]; !ok {
38✔
1255
                                        tx.firstInsertedPKs[table.name] = nl
19✔
1256
                                }
19✔
1257
                                tx.lastInsertedPKs[table.name] = nl
19✔
1258
                        }
1259

1260
                        valuesByColID[colID] = rval
8,101✔
1261
                }
1262

1263
                for i, col := range table.cols {
11,739✔
1264
                        v := valuesByColID[col.id]
9,443✔
1265

9,443✔
1266
                        if v == nil {
9,727✔
1267
                                v = NewNull(AnyType)
284✔
1268
                        } else if len(table.checkConstraints) > 0 && col.Type() == JSONType {
9,448✔
1269
                                s, _ := v.RawValue().(string)
5✔
1270
                                jsonVal, err := NewJsonFromString(s)
5✔
1271
                                if err != nil {
5✔
1272
                                        return nil, err
×
1273
                                }
×
1274
                                v = jsonVal
5✔
1275
                        }
1276

1277
                        r.ValuesByPosition[i] = v
9,443✔
1278
                        r.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
9,443✔
1279
                }
1280

1281
                if err := checkConstraints(tx, table.checkConstraints, r, table.name); err != nil {
2,302✔
1282
                        return nil, err
6✔
1283
                }
6✔
1284

1285
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
2,290✔
1286
                if err != nil {
2,295✔
1287
                        return nil, err
5✔
1288
                }
5✔
1289

1290
                // pk entry
1291
                mappedPKey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
2,285✔
1292
                if len(mappedPKey) > MaxKeyLen {
2,285✔
1293
                        return nil, ErrMaxKeyLengthExceeded
×
1294
                }
×
1295

1296
                _, err = tx.get(ctx, mappedPKey)
2,285✔
1297
                if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
2,285✔
1298
                        return nil, err
×
1299
                }
×
1300

1301
                if errors.Is(err, store.ErrKeyNotFound) && pkMustExist {
2,287✔
1302
                        return nil, fmt.Errorf("%w: specified value must be greater than current one", ErrInvalidValue)
2✔
1303
                }
2✔
1304

1305
                if stmt.isInsert {
4,385✔
1306
                        if err == nil && stmt.onConflict == nil {
2,106✔
1307
                                return nil, store.ErrKeyAlreadyExists
4✔
1308
                        }
4✔
1309

1310
                        if err == nil && stmt.onConflict != nil {
2,101✔
1311
                                // TODO: conflict resolution may be extended. Currently only supports "ON CONFLICT DO NOTHING"
3✔
1312
                                continue
3✔
1313
                        }
1314
                }
1315

1316
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, !stmt.isInsert)
2,276✔
1317
                if err != nil {
2,289✔
1318
                        return nil, err
13✔
1319
                }
13✔
1320
        }
1321
        return tx, nil
2,069✔
1322
}
1323

1324
func checkConstraints(tx *SQLTx, checks map[string]CheckConstraint, row *Row, table string) error {
2,329✔
1325
        for _, check := range checks {
2,371✔
1326
                val, err := check.exp.reduce(tx, row, table)
42✔
1327
                if err != nil {
43✔
1328
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, err)
1✔
1329
                }
1✔
1330

1331
                if val.Type() != BooleanType {
41✔
1332
                        return ErrInvalidCheckConstraint
×
1333
                }
×
1334

1335
                if !val.RawValue().(bool) {
48✔
1336
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, check.exp.String())
7✔
1337
                }
7✔
1338
        }
1339
        return nil
2,321✔
1340
}
1341

1342
func (tx *SQLTx) encodeRowValue(valuesByColID map[uint32]TypedValue, table *Table) ([]byte, error) {
2,464✔
1343
        valbuf := bytes.Buffer{}
2,464✔
1344

2,464✔
1345
        // null values are not serialized
2,464✔
1346
        encodedVals := 0
2,464✔
1347
        for _, v := range valuesByColID {
12,042✔
1348
                if !v.IsNull() {
19,138✔
1349
                        encodedVals++
9,560✔
1350
                }
9,560✔
1351
        }
1352

1353
        b := make([]byte, EncLenLen)
2,464✔
1354
        binary.BigEndian.PutUint32(b, uint32(encodedVals))
2,464✔
1355

2,464✔
1356
        _, err := valbuf.Write(b)
2,464✔
1357
        if err != nil {
2,464✔
1358
                return nil, err
×
1359
        }
×
1360

1361
        for _, col := range table.cols {
12,302✔
1362
                rval, specified := valuesByColID[col.id]
9,838✔
1363
                if !specified || rval.IsNull() {
10,122✔
1364
                        continue
284✔
1365
                }
1366

1367
                b := make([]byte, EncIDLen)
9,554✔
1368
                binary.BigEndian.PutUint32(b, uint32(col.id))
9,554✔
1369

9,554✔
1370
                _, err = valbuf.Write(b)
9,554✔
1371
                if err != nil {
9,554✔
1372
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1373
                }
×
1374

1375
                encVal, err := EncodeValue(rval, col.colType, col.MaxLen())
9,554✔
1376
                if err != nil {
9,562✔
1377
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
8✔
1378
                }
8✔
1379

1380
                _, err = valbuf.Write(encVal)
9,546✔
1381
                if err != nil {
9,546✔
1382
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1383
                }
×
1384
        }
1385

1386
        return valbuf.Bytes(), nil
2,456✔
1387
}
1388

1389
func (tx *SQLTx) doUpsert(ctx context.Context, pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table, reuseIndex bool) error {
2,307✔
1390
        var reusableIndexEntries map[uint32]struct{}
2,307✔
1391

2,307✔
1392
        if reuseIndex && len(table.indexes) > 1 {
2,364✔
1393
                currPKRow, err := tx.fetchPKRow(ctx, table, valuesByColID)
57✔
1394
                if err == nil {
93✔
1395
                        currValuesByColID := make(map[uint32]TypedValue, len(currPKRow.ValuesBySelector))
36✔
1396

36✔
1397
                        for _, col := range table.cols {
161✔
1398
                                encSel := EncodeSelector("", table.name, col.colName)
125✔
1399
                                currValuesByColID[col.id] = currPKRow.ValuesBySelector[encSel]
125✔
1400
                        }
125✔
1401

1402
                        reusableIndexEntries, err = tx.deprecateIndexEntries(pkEncVals, currValuesByColID, valuesByColID, table)
36✔
1403
                        if err != nil {
36✔
1404
                                return err
×
1405
                        }
×
1406
                } else if !errors.Is(err, ErrNoMoreRows) {
21✔
1407
                        return err
×
1408
                }
×
1409
        }
1410

1411
        rowKey := MapKey(tx.sqlPrefix(), RowPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(PKIndexID), pkEncVals)
2,307✔
1412

2,307✔
1413
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
2,307✔
1414
        if err != nil {
2,315✔
1415
                return err
8✔
1416
        }
8✔
1417

1418
        err = tx.set(rowKey, nil, encodedRowValue)
2,299✔
1419
        if err != nil {
2,299✔
1420
                return err
×
1421
        }
×
1422

1423
        // create in-memory and validate entries for secondary indexes
1424
        for _, index := range table.indexes {
5,493✔
1425
                if index.IsPrimary() {
5,493✔
1426
                        continue
2,299✔
1427
                }
1428

1429
                if reusableIndexEntries != nil {
972✔
1430
                        _, reusable := reusableIndexEntries[index.id]
77✔
1431
                        if reusable {
127✔
1432
                                continue
50✔
1433
                        }
1434
                }
1435

1436
                encodedValues := make([][]byte, 2+len(index.cols))
845✔
1437
                encodedValues[0] = EncodeID(table.id)
845✔
1438
                encodedValues[1] = EncodeID(index.id)
845✔
1439

845✔
1440
                indexKeyLen := 0
845✔
1441

845✔
1442
                for i, col := range index.cols {
1,765✔
1443
                        rval, specified := valuesByColID[col.id]
920✔
1444
                        if !specified {
993✔
1445
                                rval = &NullValue{t: col.colType}
73✔
1446
                        }
73✔
1447

1448
                        encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
920✔
1449
                        if err != nil {
920✔
1450
                                return fmt.Errorf("%w: index on '%s' and column '%s'", err, index.Name(), col.colName)
×
1451
                        }
×
1452

1453
                        if n > MaxKeyLen {
920✔
1454
                                return fmt.Errorf("%w: can not index entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1455
                        }
×
1456

1457
                        indexKeyLen += n
920✔
1458

920✔
1459
                        encodedValues[i+2] = encVal
920✔
1460
                }
1461

1462
                if indexKeyLen > MaxKeyLen {
845✔
1463
                        return fmt.Errorf("%w: can not index entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1464
                }
×
1465

1466
                smkey := MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...)
845✔
1467

845✔
1468
                // no other equivalent entry should be already indexed
845✔
1469
                if index.IsUnique() {
922✔
1470
                        _, valRef, err := tx.getWithPrefix(ctx, smkey, nil)
77✔
1471
                        if err == nil && (valRef.KVMetadata() == nil || !valRef.KVMetadata().Deleted()) {
82✔
1472
                                return store.ErrKeyAlreadyExists
5✔
1473
                        } else if !errors.Is(err, store.ErrKeyNotFound) {
77✔
1474
                                return err
×
1475
                        }
×
1476
                }
1477

1478
                err = tx.setTransient(smkey, nil, encodedRowValue) // only-indexable
840✔
1479
                if err != nil {
840✔
1480
                        return err
×
1481
                }
×
1482
        }
1483

1484
        tx.updatedRows++
2,294✔
1485

2,294✔
1486
        return nil
2,294✔
1487
}
1488

1489
func encodedKey(index *Index, valuesByColID map[uint32]TypedValue) ([]byte, error) {
13,975✔
1490
        valbuf := bytes.Buffer{}
13,975✔
1491

13,975✔
1492
        indexKeyLen := 0
13,975✔
1493

13,975✔
1494
        for _, col := range index.cols {
27,962✔
1495
                rval, specified := valuesByColID[col.id]
13,987✔
1496
                if !specified || rval.IsNull() {
13,990✔
1497
                        return nil, ErrPKCanNotBeNull
3✔
1498
                }
3✔
1499

1500
                encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
13,984✔
1501
                if err != nil {
13,986✔
1502
                        return nil, fmt.Errorf("%w: index of table '%s' and column '%s'", err, index.table.name, col.colName)
2✔
1503
                }
2✔
1504

1505
                if n > MaxKeyLen {
13,982✔
1506
                        return nil, fmt.Errorf("%w: invalid key entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1507
                }
×
1508

1509
                indexKeyLen += n
13,982✔
1510

13,982✔
1511
                _, err = valbuf.Write(encVal)
13,982✔
1512
                if err != nil {
13,982✔
1513
                        return nil, err
×
1514
                }
×
1515
        }
1516

1517
        if indexKeyLen > MaxKeyLen {
13,970✔
1518
                return nil, fmt.Errorf("%w: invalid key entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1519
        }
×
1520

1521
        return valbuf.Bytes(), nil
13,970✔
1522
}
1523

1524
func (tx *SQLTx) fetchPKRow(ctx context.Context, table *Table, valuesByColID map[uint32]TypedValue) (*Row, error) {
57✔
1525
        pkRanges := make(map[uint32]*typedValueRange, len(table.primaryIndex.cols))
57✔
1526

57✔
1527
        for _, pkCol := range table.primaryIndex.cols {
114✔
1528
                pkVal := valuesByColID[pkCol.id]
57✔
1529

57✔
1530
                pkRanges[pkCol.id] = &typedValueRange{
57✔
1531
                        lRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1532
                        hRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1533
                }
57✔
1534
        }
57✔
1535

1536
        scanSpecs := &ScanSpecs{
57✔
1537
                Index:         table.primaryIndex,
57✔
1538
                rangesByColID: pkRanges,
57✔
1539
        }
57✔
1540

57✔
1541
        r, err := newRawRowReader(tx, nil, table, period{}, table.name, scanSpecs)
57✔
1542
        if err != nil {
57✔
1543
                return nil, err
×
1544
        }
×
1545

1546
        defer func() {
114✔
1547
                r.Close()
57✔
1548
        }()
57✔
1549

1550
        return r.Read(ctx)
57✔
1551
}
1552

1553
// deprecateIndexEntries mark previous index entries as deleted
1554
func (tx *SQLTx) deprecateIndexEntries(
1555
        pkEncVals []byte,
1556
        currValuesByColID, newValuesByColID map[uint32]TypedValue,
1557
        table *Table) (reusableIndexEntries map[uint32]struct{}, err error) {
36✔
1558

36✔
1559
        encodedRowValue, err := tx.encodeRowValue(currValuesByColID, table)
36✔
1560
        if err != nil {
36✔
1561
                return nil, err
×
1562
        }
×
1563

1564
        reusableIndexEntries = make(map[uint32]struct{})
36✔
1565

36✔
1566
        for _, index := range table.indexes {
149✔
1567
                if index.IsPrimary() {
149✔
1568
                        continue
36✔
1569
                }
1570

1571
                encodedValues := make([][]byte, 2+len(index.cols)+1)
77✔
1572
                encodedValues[0] = EncodeID(table.id)
77✔
1573
                encodedValues[1] = EncodeID(index.id)
77✔
1574
                encodedValues[len(encodedValues)-1] = pkEncVals
77✔
1575

77✔
1576
                // existent index entry is deleted only if it differs from existent one
77✔
1577
                sameIndexKey := true
77✔
1578

77✔
1579
                for i, col := range index.cols {
159✔
1580
                        currVal, specified := currValuesByColID[col.id]
82✔
1581
                        if !specified {
82✔
1582
                                currVal = &NullValue{t: col.colType}
×
1583
                        }
×
1584

1585
                        newVal, specified := newValuesByColID[col.id]
82✔
1586
                        if !specified {
86✔
1587
                                newVal = &NullValue{t: col.colType}
4✔
1588
                        }
4✔
1589

1590
                        r, err := currVal.Compare(newVal)
82✔
1591
                        if err != nil {
82✔
1592
                                return nil, err
×
1593
                        }
×
1594

1595
                        sameIndexKey = sameIndexKey && r == 0
82✔
1596

82✔
1597
                        encVal, _, _ := EncodeValueAsKey(currVal, col.colType, col.MaxLen())
82✔
1598

82✔
1599
                        encodedValues[i+3] = encVal
82✔
1600
                }
1601

1602
                // mark existent index entry as deleted
1603
                if sameIndexKey {
127✔
1604
                        reusableIndexEntries[index.id] = struct{}{}
50✔
1605
                } else {
77✔
1606
                        md := store.NewKVMetadata()
27✔
1607

27✔
1608
                        md.AsDeleted(true)
27✔
1609

27✔
1610
                        err = tx.set(MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...), md, encodedRowValue)
27✔
1611
                        if err != nil {
27✔
1612
                                return nil, err
×
1613
                        }
×
1614
                }
1615
        }
1616

1617
        return reusableIndexEntries, nil
36✔
1618
}
1619

1620
type UpdateStmt struct {
1621
        tableRef *tableRef
1622
        where    ValueExp
1623
        updates  []*colUpdate
1624
        indexOn  []string
1625
        limit    ValueExp
1626
        offset   ValueExp
1627
}
1628

1629
type colUpdate struct {
1630
        col string
1631
        op  CmpOperator
1632
        val ValueExp
1633
}
1634

1635
func (stmt *UpdateStmt) readOnly() bool {
4✔
1636
        return false
4✔
1637
}
4✔
1638

1639
func (stmt *UpdateStmt) requiredPrivileges() []SQLPrivilege {
4✔
1640
        return []SQLPrivilege{SQLPrivilegeUpdate}
4✔
1641
}
4✔
1642

1643
func (stmt *UpdateStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1644
        selectStmt := &SelectStmt{
1✔
1645
                ds:    stmt.tableRef,
1✔
1646
                where: stmt.where,
1✔
1647
        }
1✔
1648

1✔
1649
        err := selectStmt.inferParameters(ctx, tx, params)
1✔
1650
        if err != nil {
1✔
1651
                return err
×
1652
        }
×
1653

1654
        table, err := stmt.tableRef.referencedTable(tx)
1✔
1655
        if err != nil {
1✔
1656
                return err
×
1657
        }
×
1658

1659
        for _, update := range stmt.updates {
2✔
1660
                col, err := table.GetColumnByName(update.col)
1✔
1661
                if err != nil {
1✔
1662
                        return err
×
1663
                }
×
1664

1665
                err = update.val.requiresType(col.colType, make(map[string]ColDescriptor), params, table.name)
1✔
1666
                if err != nil {
1✔
1667
                        return err
×
1668
                }
×
1669
        }
1670

1671
        return nil
1✔
1672
}
1673

1674
func (stmt *UpdateStmt) validate(table *Table) error {
21✔
1675
        colIDs := make(map[uint32]struct{}, len(stmt.updates))
21✔
1676

21✔
1677
        for _, update := range stmt.updates {
44✔
1678
                if update.op != EQ {
23✔
1679
                        return ErrIllegalArguments
×
1680
                }
×
1681

1682
                col, err := table.GetColumnByName(update.col)
23✔
1683
                if err != nil {
24✔
1684
                        return err
1✔
1685
                }
1✔
1686

1687
                if table.PrimaryIndex().IncludesCol(col.id) {
22✔
1688
                        return ErrPKCanNotBeUpdated
×
1689
                }
×
1690

1691
                _, duplicated := colIDs[col.id]
22✔
1692
                if duplicated {
22✔
1693
                        return ErrDuplicatedColumn
×
1694
                }
×
1695

1696
                colIDs[col.id] = struct{}{}
22✔
1697
        }
1698

1699
        return nil
20✔
1700
}
1701

1702
func (stmt *UpdateStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
22✔
1703
        selectStmt := &SelectStmt{
22✔
1704
                ds:      stmt.tableRef,
22✔
1705
                where:   stmt.where,
22✔
1706
                indexOn: stmt.indexOn,
22✔
1707
                limit:   stmt.limit,
22✔
1708
                offset:  stmt.offset,
22✔
1709
        }
22✔
1710

22✔
1711
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
22✔
1712
        if err != nil {
23✔
1713
                return nil, err
1✔
1714
        }
1✔
1715
        defer rowReader.Close()
21✔
1716

21✔
1717
        table := rowReader.ScanSpecs().Index.table
21✔
1718

21✔
1719
        err = stmt.validate(table)
21✔
1720
        if err != nil {
22✔
1721
                return nil, err
1✔
1722
        }
1✔
1723

1724
        cols, err := rowReader.colsBySelector(ctx)
20✔
1725
        if err != nil {
20✔
1726
                return nil, err
×
1727
        }
×
1728

1729
        for {
71✔
1730
                row, err := rowReader.Read(ctx)
51✔
1731
                if errors.Is(err, ErrNoMoreRows) {
68✔
1732
                        break
17✔
1733
                } else if err != nil {
35✔
1734
                        return nil, err
1✔
1735
                }
1✔
1736

1737
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
33✔
1738

33✔
1739
                for _, col := range table.cols {
124✔
1740
                        encSel := EncodeSelector("", table.name, col.colName)
91✔
1741
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
91✔
1742
                }
91✔
1743

1744
                for _, update := range stmt.updates {
68✔
1745
                        col, err := table.GetColumnByName(update.col)
35✔
1746
                        if err != nil {
35✔
1747
                                return nil, err
×
1748
                        }
×
1749

1750
                        sval, err := update.val.substitute(params)
35✔
1751
                        if err != nil {
35✔
1752
                                return nil, err
×
1753
                        }
×
1754

1755
                        rval, err := sval.reduce(tx, row, table.name)
35✔
1756
                        if err != nil {
35✔
1757
                                return nil, err
×
1758
                        }
×
1759

1760
                        err = rval.requiresType(col.colType, cols, nil, table.name)
35✔
1761
                        if err != nil {
35✔
1762
                                return nil, err
×
1763
                        }
×
1764

1765
                        valuesByColID[col.id] = rval
35✔
1766
                }
1767

1768
                for i, col := range table.cols {
124✔
1769
                        v := valuesByColID[col.id]
91✔
1770

91✔
1771
                        row.ValuesByPosition[i] = v
91✔
1772
                        row.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
91✔
1773
                }
91✔
1774

1775
                if err := checkConstraints(tx, table.checkConstraints, row, table.name); err != nil {
35✔
1776
                        return nil, err
2✔
1777
                }
2✔
1778

1779
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
31✔
1780
                if err != nil {
31✔
1781
                        return nil, err
×
1782
                }
×
1783

1784
                // primary index entry
1785
                mkey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
31✔
1786

31✔
1787
                // mkey must exist
31✔
1788
                _, err = tx.get(ctx, mkey)
31✔
1789
                if err != nil {
31✔
1790
                        return nil, err
×
1791
                }
×
1792

1793
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, true)
31✔
1794
                if err != nil {
31✔
1795
                        return nil, err
×
1796
                }
×
1797
        }
1798

1799
        return tx, nil
17✔
1800
}
1801

1802
type DeleteFromStmt struct {
1803
        tableRef *tableRef
1804
        where    ValueExp
1805
        indexOn  []string
1806
        orderBy  []*OrdExp
1807
        limit    ValueExp
1808
        offset   ValueExp
1809
}
1810

1811
func NewDeleteFromStmt(table string, where ValueExp, orderBy []*OrdExp, limit ValueExp) *DeleteFromStmt {
4✔
1812
        return &DeleteFromStmt{
4✔
1813
                tableRef: NewTableRef(table, ""),
4✔
1814
                where:    where,
4✔
1815
                orderBy:  orderBy,
4✔
1816
                limit:    limit,
4✔
1817
        }
4✔
1818
}
4✔
1819

1820
func (stmt *DeleteFromStmt) readOnly() bool {
1✔
1821
        return false
1✔
1822
}
1✔
1823

1824
func (stmt *DeleteFromStmt) requiredPrivileges() []SQLPrivilege {
1✔
1825
        return []SQLPrivilege{SQLPrivilegeDelete}
1✔
1826
}
1✔
1827

1828
func (stmt *DeleteFromStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1829
        selectStmt := &SelectStmt{
1✔
1830
                ds:      stmt.tableRef,
1✔
1831
                where:   stmt.where,
1✔
1832
                orderBy: stmt.orderBy,
1✔
1833
        }
1✔
1834
        return selectStmt.inferParameters(ctx, tx, params)
1✔
1835
}
1✔
1836

1837
func (stmt *DeleteFromStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
1838
        selectStmt := &SelectStmt{
15✔
1839
                ds:      stmt.tableRef,
15✔
1840
                where:   stmt.where,
15✔
1841
                indexOn: stmt.indexOn,
15✔
1842
                orderBy: stmt.orderBy,
15✔
1843
                limit:   stmt.limit,
15✔
1844
                offset:  stmt.offset,
15✔
1845
        }
15✔
1846

15✔
1847
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
15✔
1848
        if err != nil {
17✔
1849
                return nil, err
2✔
1850
        }
2✔
1851
        defer rowReader.Close()
13✔
1852

13✔
1853
        table := rowReader.ScanSpecs().Index.table
13✔
1854

13✔
1855
        for {
147✔
1856
                row, err := rowReader.Read(ctx)
134✔
1857
                if errors.Is(err, ErrNoMoreRows) {
146✔
1858
                        break
12✔
1859
                }
1860
                if err != nil {
123✔
1861
                        return nil, err
1✔
1862
                }
1✔
1863

1864
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
121✔
1865

121✔
1866
                for _, col := range table.cols {
406✔
1867
                        encSel := EncodeSelector("", table.name, col.colName)
285✔
1868
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
285✔
1869
                }
285✔
1870

1871
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
121✔
1872
                if err != nil {
121✔
1873
                        return nil, err
×
1874
                }
×
1875

1876
                err = tx.deleteIndexEntries(pkEncVals, valuesByColID, table)
121✔
1877
                if err != nil {
121✔
1878
                        return nil, err
×
1879
                }
×
1880

1881
                tx.updatedRows++
121✔
1882
        }
1883
        return tx, nil
12✔
1884
}
1885

1886
func (tx *SQLTx) deleteIndexEntries(pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table) error {
121✔
1887
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
121✔
1888
        if err != nil {
121✔
1889
                return err
×
1890
        }
×
1891

1892
        for _, index := range table.indexes {
291✔
1893
                if !index.IsPrimary() {
219✔
1894
                        continue
49✔
1895
                }
1896

1897
                encodedValues := make([][]byte, 3+len(index.cols))
121✔
1898
                encodedValues[0] = EncodeID(DatabaseID)
121✔
1899
                encodedValues[1] = EncodeID(table.id)
121✔
1900
                encodedValues[2] = EncodeID(index.id)
121✔
1901

121✔
1902
                for i, col := range index.cols {
242✔
1903
                        val, specified := valuesByColID[col.id]
121✔
1904
                        if !specified {
121✔
1905
                                val = &NullValue{t: col.colType}
×
1906
                        }
×
1907

1908
                        encVal, _, _ := EncodeValueAsKey(val, col.colType, col.MaxLen())
121✔
1909

121✔
1910
                        encodedValues[i+3] = encVal
121✔
1911
                }
1912

1913
                md := store.NewKVMetadata()
121✔
1914

121✔
1915
                md.AsDeleted(true)
121✔
1916

121✔
1917
                err := tx.set(MapKey(tx.sqlPrefix(), RowPrefix, encodedValues...), md, encodedRowValue)
121✔
1918
                if err != nil {
121✔
1919
                        return err
×
1920
                }
×
1921
        }
1922

1923
        return nil
121✔
1924
}
1925

1926
type ValueExp interface {
1927
        inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
1928
        requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
1929
        substitute(params map[string]interface{}) (ValueExp, error)
1930
        selectors() []Selector
1931
        reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error)
1932
        reduceSelectors(row *Row, implicitTable string) ValueExp
1933
        isConstant() bool
1934
        selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error
1935
        String() string
1936
}
1937

1938
type typedValueRange struct {
1939
        lRange *typedValueSemiRange
1940
        hRange *typedValueSemiRange
1941
}
1942

1943
type typedValueSemiRange struct {
1944
        val       TypedValue
1945
        inclusive bool
1946
}
1947

1948
func (r *typedValueRange) unitary() bool {
19✔
1949
        // TODO: this simplified implementation doesn't cover all unitary cases e.g. 3<=v<4
19✔
1950
        if r.lRange == nil || r.hRange == nil {
19✔
1951
                return false
×
1952
        }
×
1953

1954
        res, _ := r.lRange.val.Compare(r.hRange.val)
19✔
1955
        return res == 0 && r.lRange.inclusive && r.hRange.inclusive
19✔
1956
}
1957

1958
func (r *typedValueRange) refineWith(refiningRange *typedValueRange) error {
5✔
1959
        if r.lRange == nil {
6✔
1960
                r.lRange = refiningRange.lRange
1✔
1961
        } else if r.lRange != nil && refiningRange.lRange != nil {
6✔
1962
                maxRange, err := maxSemiRange(r.lRange, refiningRange.lRange)
1✔
1963
                if err != nil {
1✔
1964
                        return err
×
1965
                }
×
1966
                r.lRange = maxRange
1✔
1967
        }
1968

1969
        if r.hRange == nil {
8✔
1970
                r.hRange = refiningRange.hRange
3✔
1971
        } else if r.hRange != nil && refiningRange.hRange != nil {
7✔
1972
                minRange, err := minSemiRange(r.hRange, refiningRange.hRange)
2✔
1973
                if err != nil {
2✔
1974
                        return err
×
1975
                }
×
1976
                r.hRange = minRange
2✔
1977
        }
1978

1979
        return nil
5✔
1980
}
1981

1982
func (r *typedValueRange) extendWith(extendingRange *typedValueRange) error {
5✔
1983
        if r.lRange == nil || extendingRange.lRange == nil {
7✔
1984
                r.lRange = nil
2✔
1985
        } else {
5✔
1986
                minRange, err := minSemiRange(r.lRange, extendingRange.lRange)
3✔
1987
                if err != nil {
3✔
1988
                        return err
×
1989
                }
×
1990
                r.lRange = minRange
3✔
1991
        }
1992

1993
        if r.hRange == nil || extendingRange.hRange == nil {
8✔
1994
                r.hRange = nil
3✔
1995
        } else {
5✔
1996
                maxRange, err := maxSemiRange(r.hRange, extendingRange.hRange)
2✔
1997
                if err != nil {
2✔
1998
                        return err
×
1999
                }
×
2000
                r.hRange = maxRange
2✔
2001
        }
2002

2003
        return nil
5✔
2004
}
2005

2006
func maxSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
3✔
2007
        r, err := or1.val.Compare(or2.val)
3✔
2008
        if err != nil {
3✔
2009
                return nil, err
×
2010
        }
×
2011

2012
        maxVal := or1.val
3✔
2013
        if r < 0 {
5✔
2014
                maxVal = or2.val
2✔
2015
        }
2✔
2016

2017
        return &typedValueSemiRange{
3✔
2018
                val:       maxVal,
3✔
2019
                inclusive: or1.inclusive && or2.inclusive,
3✔
2020
        }, nil
3✔
2021
}
2022

2023
func minSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
5✔
2024
        r, err := or1.val.Compare(or2.val)
5✔
2025
        if err != nil {
5✔
2026
                return nil, err
×
2027
        }
×
2028

2029
        minVal := or1.val
5✔
2030
        if r > 0 {
9✔
2031
                minVal = or2.val
4✔
2032
        }
4✔
2033

2034
        return &typedValueSemiRange{
5✔
2035
                val:       minVal,
5✔
2036
                inclusive: or1.inclusive || or2.inclusive,
5✔
2037
        }, nil
5✔
2038
}
2039

2040
type TypedValue interface {
2041
        ValueExp
2042
        Type() SQLValueType
2043
        RawValue() interface{}
2044
        Compare(val TypedValue) (int, error)
2045
        IsNull() bool
2046
}
2047

2048
type Tuple []TypedValue
2049

2050
func (t Tuple) Compare(other Tuple) (int, int, error) {
204,158✔
2051
        if len(t) != len(other) {
204,158✔
2052
                return -1, -1, ErrNotComparableValues
×
2053
        }
×
2054

2055
        for i := range t {
431,257✔
2056
                res, err := t[i].Compare(other[i])
227,099✔
2057
                if err != nil || res != 0 {
420,947✔
2058
                        return res, i, err
193,848✔
2059
                }
193,848✔
2060
        }
2061
        return 0, -1, nil
10,310✔
2062
}
2063

2064
func NewNull(t SQLValueType) *NullValue {
387✔
2065
        return &NullValue{t: t}
387✔
2066
}
387✔
2067

2068
type NullValue struct {
2069
        t SQLValueType
2070
}
2071

2072
func (n *NullValue) Type() SQLValueType {
108✔
2073
        return n.t
108✔
2074
}
108✔
2075

2076
func (n *NullValue) RawValue() interface{} {
353✔
2077
        return nil
353✔
2078
}
353✔
2079

2080
func (n *NullValue) IsNull() bool {
377✔
2081
        return true
377✔
2082
}
377✔
2083

2084
func (n *NullValue) String() string {
4✔
2085
        return "NULL"
4✔
2086
}
4✔
2087

2088
func (n *NullValue) Compare(val TypedValue) (int, error) {
82✔
2089
        if n.t != AnyType && val.Type() != AnyType && n.t != val.Type() {
83✔
2090
                return 0, ErrNotComparableValues
1✔
2091
        }
1✔
2092

2093
        if val.RawValue() == nil {
122✔
2094
                return 0, nil
41✔
2095
        }
41✔
2096
        return -1, nil
40✔
2097
}
2098

2099
func (v *NullValue) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
7✔
2100
        return v.t, nil
7✔
2101
}
7✔
2102

2103
func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
10✔
2104
        if v.t == t {
15✔
2105
                return nil
5✔
2106
        }
5✔
2107

2108
        if v.t != AnyType {
6✔
2109
                return ErrInvalidTypes
1✔
2110
        }
1✔
2111

2112
        v.t = t
4✔
2113

4✔
2114
        return nil
4✔
2115
}
2116

2117
func (v *NullValue) selectors() []Selector {
13✔
2118
        return nil
13✔
2119
}
13✔
2120

2121
func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) {
394✔
2122
        return v, nil
394✔
2123
}
394✔
2124

2125
func (v *NullValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
348✔
2126
        return v, nil
348✔
2127
}
348✔
2128

2129
func (v *NullValue) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
2130
        return v
10✔
2131
}
10✔
2132

2133
func (v *NullValue) isConstant() bool {
12✔
2134
        return true
12✔
2135
}
12✔
2136

2137
func (v *NullValue) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2138
        return nil
1✔
2139
}
1✔
2140

2141
type Integer struct {
2142
        val int64
2143
}
2144

2145
func NewInteger(val int64) *Integer {
315✔
2146
        return &Integer{val: val}
315✔
2147
}
315✔
2148

2149
func (v *Integer) Type() SQLValueType {
302,808✔
2150
        return IntegerType
302,808✔
2151
}
302,808✔
2152

2153
func (v *Integer) IsNull() bool {
116,548✔
2154
        return false
116,548✔
2155
}
116,548✔
2156

2157
func (v *Integer) String() string {
54✔
2158
        return strconv.FormatInt(v.val, 10)
54✔
2159
}
54✔
2160

2161
func (v *Integer) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
98✔
2162
        return IntegerType, nil
98✔
2163
}
98✔
2164

2165
func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
63✔
2166
        if t != IntegerType && t != JSONType {
67✔
2167
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
4✔
2168
        }
4✔
2169

2170
        return nil
59✔
2171
}
2172

2173
func (v *Integer) selectors() []Selector {
51✔
2174
        return nil
51✔
2175
}
51✔
2176

2177
func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) {
15,158✔
2178
        return v, nil
15,158✔
2179
}
15,158✔
2180

2181
func (v *Integer) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
16,907✔
2182
        return v, nil
16,907✔
2183
}
16,907✔
2184

2185
func (v *Integer) reduceSelectors(row *Row, implicitTable string) ValueExp {
9✔
2186
        return v
9✔
2187
}
9✔
2188

2189
func (v *Integer) isConstant() bool {
120✔
2190
        return true
120✔
2191
}
120✔
2192

2193
func (v *Integer) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2194
        return nil
1✔
2195
}
1✔
2196

2197
func (v *Integer) RawValue() interface{} {
177,191✔
2198
        return v.val
177,191✔
2199
}
177,191✔
2200

2201
func (v *Integer) Compare(val TypedValue) (int, error) {
93,040✔
2202
        if val.IsNull() {
93,083✔
2203
                return 1, nil
43✔
2204
        }
43✔
2205

2206
        if val.Type() == JSONType {
92,998✔
2207
                res, err := val.Compare(v)
1✔
2208
                return -res, err
1✔
2209
        }
1✔
2210

2211
        if val.Type() == Float64Type {
92,996✔
2212
                r, err := val.Compare(v)
×
2213
                return r * -1, err
×
2214
        }
×
2215

2216
        if val.Type() != IntegerType {
93,003✔
2217
                return 0, ErrNotComparableValues
7✔
2218
        }
7✔
2219

2220
        rval := val.RawValue().(int64)
92,989✔
2221

92,989✔
2222
        if v.val == rval {
111,637✔
2223
                return 0, nil
18,648✔
2224
        }
18,648✔
2225

2226
        if v.val > rval {
109,217✔
2227
                return 1, nil
34,876✔
2228
        }
34,876✔
2229

2230
        return -1, nil
39,465✔
2231
}
2232

2233
type Timestamp struct {
2234
        val time.Time
2235
}
2236

2237
func (v *Timestamp) Type() SQLValueType {
40,315✔
2238
        return TimestampType
40,315✔
2239
}
40,315✔
2240

2241
func (v *Timestamp) IsNull() bool {
32,894✔
2242
        return false
32,894✔
2243
}
32,894✔
2244

2245
func (v *Timestamp) String() string {
1✔
2246
        return v.val.Format("2006-01-02 15:04:05.999999")
1✔
2247
}
1✔
2248

2249
func (v *Timestamp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2250
        return TimestampType, nil
1✔
2251
}
1✔
2252

2253
func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
14✔
2254
        if t != TimestampType {
15✔
2255
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2256
        }
1✔
2257

2258
        return nil
13✔
2259
}
2260

2261
func (v *Timestamp) selectors() []Selector {
1✔
2262
        return nil
1✔
2263
}
1✔
2264

2265
func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) {
1,165✔
2266
        return v, nil
1,165✔
2267
}
1,165✔
2268

2269
func (v *Timestamp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2,025✔
2270
        return v, nil
2,025✔
2271
}
2,025✔
2272

2273
func (v *Timestamp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2274
        return v
1✔
2275
}
1✔
2276

2277
func (v *Timestamp) isConstant() bool {
1✔
2278
        return true
1✔
2279
}
1✔
2280

2281
func (v *Timestamp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2282
        return nil
1✔
2283
}
1✔
2284

2285
func (v *Timestamp) RawValue() interface{} {
57,410✔
2286
        return v.val
57,410✔
2287
}
57,410✔
2288

2289
func (v *Timestamp) Compare(val TypedValue) (int, error) {
29,744✔
2290
        if val.IsNull() {
29,746✔
2291
                return 1, nil
2✔
2292
        }
2✔
2293

2294
        if val.Type() != TimestampType {
29,743✔
2295
                return 0, ErrNotComparableValues
1✔
2296
        }
1✔
2297

2298
        rval := val.RawValue().(time.Time)
29,741✔
2299

29,741✔
2300
        if v.val.Before(rval) {
44,251✔
2301
                return -1, nil
14,510✔
2302
        }
14,510✔
2303

2304
        if v.val.After(rval) {
30,271✔
2305
                return 1, nil
15,040✔
2306
        }
15,040✔
2307

2308
        return 0, nil
191✔
2309
}
2310

2311
type Varchar struct {
2312
        val string
2313
}
2314

2315
func NewVarchar(val string) *Varchar {
2,051✔
2316
        return &Varchar{val: val}
2,051✔
2317
}
2,051✔
2318

2319
func (v *Varchar) Type() SQLValueType {
127,944✔
2320
        return VarcharType
127,944✔
2321
}
127,944✔
2322

2323
func (v *Varchar) IsNull() bool {
64,571✔
2324
        return false
64,571✔
2325
}
64,571✔
2326

2327
func (v *Varchar) String() string {
18✔
2328
        return fmt.Sprintf("'%s'", v.val)
18✔
2329
}
18✔
2330

2331
func (v *Varchar) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
60✔
2332
        return VarcharType, nil
60✔
2333
}
60✔
2334

2335
func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
140✔
2336
        if t != VarcharType && t != JSONType {
142✔
2337
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
2✔
2338
        }
2✔
2339
        return nil
138✔
2340
}
2341

2342
func (v *Varchar) selectors() []Selector {
39✔
2343
        return nil
39✔
2344
}
39✔
2345

2346
func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) {
4,949✔
2347
        return v, nil
4,949✔
2348
}
4,949✔
2349

2350
func (v *Varchar) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5,635✔
2351
        return v, nil
5,635✔
2352
}
5,635✔
2353

2354
func (v *Varchar) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2355
        return v
×
2356
}
×
2357

2358
func (v *Varchar) isConstant() bool {
39✔
2359
        return true
39✔
2360
}
39✔
2361

2362
func (v *Varchar) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2363
        return nil
1✔
2364
}
1✔
2365

2366
func (v *Varchar) RawValue() interface{} {
90,157✔
2367
        return v.val
90,157✔
2368
}
90,157✔
2369

2370
func (v *Varchar) Compare(val TypedValue) (int, error) {
58,120✔
2371
        if val.IsNull() {
58,174✔
2372
                return 1, nil
54✔
2373
        }
54✔
2374

2375
        if val.Type() == JSONType {
59,067✔
2376
                res, err := val.Compare(v)
1,001✔
2377
                return -res, err
1,001✔
2378
        }
1,001✔
2379

2380
        if val.Type() != VarcharType {
57,066✔
2381
                return 0, ErrNotComparableValues
1✔
2382
        }
1✔
2383

2384
        rval := val.RawValue().(string)
57,064✔
2385

57,064✔
2386
        return bytes.Compare([]byte(v.val), []byte(rval)), nil
57,064✔
2387
}
2388

2389
type UUID struct {
2390
        val uuid.UUID
2391
}
2392

2393
func NewUUID(val uuid.UUID) *UUID {
1✔
2394
        return &UUID{val: val}
1✔
2395
}
1✔
2396

2397
func (v *UUID) Type() SQLValueType {
10✔
2398
        return UUIDType
10✔
2399
}
10✔
2400

2401
func (v *UUID) IsNull() bool {
26✔
2402
        return false
26✔
2403
}
26✔
2404

2405
func (v *UUID) String() string {
1✔
2406
        return v.val.String()
1✔
2407
}
1✔
2408

2409
func (v *UUID) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2410
        return UUIDType, nil
1✔
2411
}
1✔
2412

2413
func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
2414
        if t != UUIDType {
6✔
2415
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
2✔
2416
        }
2✔
2417

2418
        return nil
2✔
2419
}
2420

2421
func (v *UUID) selectors() []Selector {
1✔
2422
        return nil
1✔
2423
}
1✔
2424

2425
func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) {
6✔
2426
        return v, nil
6✔
2427
}
6✔
2428

2429
func (v *UUID) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5✔
2430
        return v, nil
5✔
2431
}
5✔
2432

2433
func (v *UUID) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2434
        return v
1✔
2435
}
1✔
2436

2437
func (v *UUID) isConstant() bool {
1✔
2438
        return true
1✔
2439
}
1✔
2440

2441
func (v *UUID) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2442
        return nil
1✔
2443
}
1✔
2444

2445
func (v *UUID) RawValue() interface{} {
41✔
2446
        return v.val
41✔
2447
}
41✔
2448

2449
func (v *UUID) Compare(val TypedValue) (int, error) {
5✔
2450
        if val.IsNull() {
7✔
2451
                return 1, nil
2✔
2452
        }
2✔
2453

2454
        if val.Type() != UUIDType {
4✔
2455
                return 0, ErrNotComparableValues
1✔
2456
        }
1✔
2457

2458
        rval := val.RawValue().(uuid.UUID)
2✔
2459

2✔
2460
        return bytes.Compare(v.val[:], rval[:]), nil
2✔
2461
}
2462

2463
type Bool struct {
2464
        val bool
2465
}
2466

2467
func NewBool(val bool) *Bool {
208✔
2468
        return &Bool{val: val}
208✔
2469
}
208✔
2470

2471
func (v *Bool) Type() SQLValueType {
1,964✔
2472
        return BooleanType
1,964✔
2473
}
1,964✔
2474

2475
func (v *Bool) IsNull() bool {
1,479✔
2476
        return false
1,479✔
2477
}
1,479✔
2478

2479
func (v *Bool) String() string {
41✔
2480
        return strconv.FormatBool(v.val)
41✔
2481
}
41✔
2482

2483
func (v *Bool) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
32✔
2484
        return BooleanType, nil
32✔
2485
}
32✔
2486

2487
func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
57✔
2488
        if t != BooleanType && t != JSONType {
62✔
2489
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
5✔
2490
        }
5✔
2491
        return nil
52✔
2492
}
2493

2494
func (v *Bool) selectors() []Selector {
5✔
2495
        return nil
5✔
2496
}
5✔
2497

2498
func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) {
639✔
2499
        return v, nil
639✔
2500
}
639✔
2501

2502
func (v *Bool) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2503
        return v, nil
726✔
2504
}
726✔
2505

2506
func (v *Bool) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2507
        return v
×
2508
}
×
2509

2510
func (v *Bool) isConstant() bool {
3✔
2511
        return true
3✔
2512
}
3✔
2513

2514
func (v *Bool) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
2515
        return nil
8✔
2516
}
8✔
2517

2518
func (v *Bool) RawValue() interface{} {
1,710✔
2519
        return v.val
1,710✔
2520
}
1,710✔
2521

2522
func (v *Bool) Compare(val TypedValue) (int, error) {
575✔
2523
        if val.IsNull() {
605✔
2524
                return 1, nil
30✔
2525
        }
30✔
2526

2527
        if val.Type() == JSONType {
546✔
2528
                res, err := val.Compare(v)
1✔
2529
                return -res, err
1✔
2530
        }
1✔
2531

2532
        if val.Type() != BooleanType {
544✔
2533
                return 0, ErrNotComparableValues
×
2534
        }
×
2535

2536
        rval := val.RawValue().(bool)
544✔
2537

544✔
2538
        if v.val == rval {
888✔
2539
                return 0, nil
344✔
2540
        }
344✔
2541

2542
        if v.val {
206✔
2543
                return 1, nil
6✔
2544
        }
6✔
2545

2546
        return -1, nil
194✔
2547
}
2548

2549
type Blob struct {
2550
        val []byte
2551
}
2552

2553
func NewBlob(val []byte) *Blob {
286✔
2554
        return &Blob{val: val}
286✔
2555
}
286✔
2556

2557
func (v *Blob) Type() SQLValueType {
53✔
2558
        return BLOBType
53✔
2559
}
53✔
2560

2561
func (v *Blob) IsNull() bool {
2,312✔
2562
        return false
2,312✔
2563
}
2,312✔
2564

2565
func (v *Blob) String() string {
2✔
2566
        return hex.EncodeToString(v.val)
2✔
2567
}
2✔
2568

2569
func (v *Blob) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2570
        return BLOBType, nil
1✔
2571
}
1✔
2572

2573
func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
2574
        if t != BLOBType {
3✔
2575
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BLOBType, t)
1✔
2576
        }
1✔
2577

2578
        return nil
1✔
2579
}
2580

2581
func (v *Blob) selectors() []Selector {
1✔
2582
        return nil
1✔
2583
}
1✔
2584

2585
func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) {
714✔
2586
        return v, nil
714✔
2587
}
714✔
2588

2589
func (v *Blob) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2590
        return v, nil
726✔
2591
}
726✔
2592

2593
func (v *Blob) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2594
        return v
×
2595
}
×
2596

2597
func (v *Blob) isConstant() bool {
7✔
2598
        return true
7✔
2599
}
7✔
2600

2601
func (v *Blob) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2602
        return nil
×
2603
}
×
2604

2605
func (v *Blob) RawValue() interface{} {
2,570✔
2606
        return v.val
2,570✔
2607
}
2,570✔
2608

2609
func (v *Blob) Compare(val TypedValue) (int, error) {
25✔
2610
        if val.IsNull() {
25✔
2611
                return 1, nil
×
2612
        }
×
2613

2614
        if val.Type() != BLOBType {
25✔
2615
                return 0, ErrNotComparableValues
×
2616
        }
×
2617

2618
        rval := val.RawValue().([]byte)
25✔
2619

25✔
2620
        return bytes.Compare(v.val, rval), nil
25✔
2621
}
2622

2623
type Float64 struct {
2624
        val float64
2625
}
2626

2627
func NewFloat64(val float64) *Float64 {
1,249✔
2628
        return &Float64{val: val}
1,249✔
2629
}
1,249✔
2630

2631
func (v *Float64) Type() SQLValueType {
207,436✔
2632
        return Float64Type
207,436✔
2633
}
207,436✔
2634

2635
func (v *Float64) IsNull() bool {
5,598✔
2636
        return false
5,598✔
2637
}
5,598✔
2638

2639
func (v *Float64) String() string {
3✔
2640
        return strconv.FormatFloat(float64(v.val), 'f', -1, 64)
3✔
2641
}
3✔
2642

2643
func (v *Float64) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
15✔
2644
        return Float64Type, nil
15✔
2645
}
15✔
2646

2647
func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2648
        if t != Float64Type && t != JSONType {
21✔
2649
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, Float64Type, t)
1✔
2650
        }
1✔
2651
        return nil
19✔
2652
}
2653

2654
func (v *Float64) selectors() []Selector {
2✔
2655
        return nil
2✔
2656
}
2✔
2657

2658
func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) {
1,817✔
2659
        return v, nil
1,817✔
2660
}
1,817✔
2661

2662
func (v *Float64) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
3,428✔
2663
        return v, nil
3,428✔
2664
}
3,428✔
2665

2666
func (v *Float64) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2667
        return v
1✔
2668
}
1✔
2669

2670
func (v *Float64) isConstant() bool {
5✔
2671
        return true
5✔
2672
}
5✔
2673

2674
func (v *Float64) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2675
        return nil
1✔
2676
}
1✔
2677

2678
func (v *Float64) RawValue() interface{} {
365,967✔
2679
        return v.val
365,967✔
2680
}
365,967✔
2681

2682
func (v *Float64) Compare(val TypedValue) (int, error) {
61,876✔
2683
        if val.Type() == JSONType {
61,877✔
2684
                res, err := val.Compare(v)
1✔
2685
                return -res, err
1✔
2686
        }
1✔
2687

2688
        convVal, err := mayApplyImplicitConversion(val.RawValue(), Float64Type)
61,875✔
2689
        if err != nil {
61,876✔
2690
                return 0, err
1✔
2691
        }
1✔
2692

2693
        if convVal == nil {
61,877✔
2694
                return 1, nil
3✔
2695
        }
3✔
2696

2697
        rval, ok := convVal.(float64)
61,871✔
2698
        if !ok {
61,871✔
2699
                return 0, ErrNotComparableValues
×
2700
        }
×
2701

2702
        if v.val == rval {
61,996✔
2703
                return 0, nil
125✔
2704
        }
125✔
2705

2706
        if v.val > rval {
90,619✔
2707
                return 1, nil
28,873✔
2708
        }
28,873✔
2709

2710
        return -1, nil
32,873✔
2711
}
2712

2713
type FnCall struct {
2714
        fn     string
2715
        params []ValueExp
2716
}
2717

2718
func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2719
        fn, err := v.resolveFunc()
23✔
2720
        if err != nil {
24✔
2721
                return AnyType, nil
1✔
2722
        }
1✔
2723
        return fn.InferType(cols, params, implicitTable)
22✔
2724
}
2725

2726
func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
18✔
2727
        fn, err := v.resolveFunc()
18✔
2728
        if err != nil {
19✔
2729
                return err
1✔
2730
        }
1✔
2731
        return fn.RequiresType(t, cols, params, implicitTable)
17✔
2732
}
2733

2734
func (v *FnCall) selectors() []Selector {
33✔
2735
        selectors := make([]Selector, 0)
33✔
2736
        for _, param := range v.params {
89✔
2737
                selectors = append(selectors, param.selectors()...)
56✔
2738
        }
56✔
2739
        return selectors
33✔
2740
}
2741

2742
func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) {
433✔
2743
        ps := make([]ValueExp, len(v.params))
433✔
2744
        for i, p := range v.params {
791✔
2745
                ps[i], err = p.substitute(params)
358✔
2746
                if err != nil {
358✔
2747
                        return nil, err
×
2748
                }
×
2749
        }
2750

2751
        return &FnCall{
433✔
2752
                fn:     v.fn,
433✔
2753
                params: ps,
433✔
2754
        }, nil
433✔
2755
}
2756

2757
func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
433✔
2758
        fn, err := v.resolveFunc()
433✔
2759
        if err != nil {
434✔
2760
                return nil, err
1✔
2761
        }
1✔
2762

2763
        fnInputs, err := v.reduceParams(tx, row, implicitTable)
432✔
2764
        if err != nil {
432✔
2765
                return nil, err
×
2766
        }
×
2767
        return fn.Apply(tx, fnInputs)
432✔
2768
}
2769

2770
func (v *FnCall) reduceParams(tx *SQLTx, row *Row, implicitTable string) ([]TypedValue, error) {
432✔
2771
        var values []TypedValue
432✔
2772
        if len(v.params) > 0 {
763✔
2773
                values = make([]TypedValue, len(v.params))
331✔
2774
                for i, p := range v.params {
689✔
2775
                        v, err := p.reduce(tx, row, implicitTable)
358✔
2776
                        if err != nil {
358✔
2777
                                return nil, err
×
2778
                        }
×
2779
                        values[i] = v
358✔
2780
                }
2781
        }
2782
        return values, nil
432✔
2783
}
2784

2785
func (v *FnCall) resolveFunc() (Function, error) {
474✔
2786
        fn, exists := builtinFunctions[strings.ToUpper(v.fn)]
474✔
2787
        if !exists {
477✔
2788
                return nil, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn)
3✔
2789
        }
3✔
2790
        return fn, nil
471✔
2791
}
2792

2793
func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2794
        return v
×
2795
}
×
2796

2797
func (v *FnCall) isConstant() bool {
13✔
2798
        return false
13✔
2799
}
13✔
2800

2801
func (v *FnCall) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2802
        return nil
×
2803
}
×
2804

2805
func (v *FnCall) String() string {
1✔
2806
        params := make([]string, len(v.params))
1✔
2807
        for i, p := range v.params {
4✔
2808
                params[i] = p.String()
3✔
2809
        }
3✔
2810
        return v.fn + "(" + strings.Join(params, ",") + ")"
1✔
2811
}
2812

2813
type Cast struct {
2814
        val ValueExp
2815
        t   SQLValueType
2816
}
2817

2818
func (c *Cast) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2819
        _, err := c.val.inferType(cols, params, implicitTable)
23✔
2820
        if err != nil {
24✔
2821
                return AnyType, err
1✔
2822
        }
1✔
2823

2824
        // val type may be restricted by compatible conversions, but multiple types may be compatible...
2825

2826
        return c.t, nil
22✔
2827
}
2828

2829
func (c *Cast) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
2830
        if c.t != t {
×
2831
                return fmt.Errorf("%w: can not use value cast to %s as %s", ErrInvalidTypes, c.t, t)
×
2832
        }
×
2833

2834
        return nil
×
2835
}
2836

2837
func (c *Cast) substitute(params map[string]interface{}) (ValueExp, error) {
281✔
2838
        val, err := c.val.substitute(params)
281✔
2839
        if err != nil {
281✔
2840
                return nil, err
×
2841
        }
×
2842
        c.val = val
281✔
2843
        return c, nil
281✔
2844
}
2845

2846
func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
269✔
2847
        val, err := c.val.reduce(tx, row, implicitTable)
269✔
2848
        if err != nil {
269✔
2849
                return nil, err
×
2850
        }
×
2851

2852
        conv, err := getConverter(val.Type(), c.t)
269✔
2853
        if conv == nil {
272✔
2854
                return nil, err
3✔
2855
        }
3✔
2856

2857
        return conv(val)
266✔
2858
}
2859

2860
func (v *Cast) selectors() []Selector {
6✔
2861
        return v.val.selectors()
6✔
2862
}
6✔
2863

2864
func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2865
        return &Cast{
×
2866
                val: c.val.reduceSelectors(row, implicitTable),
×
2867
                t:   c.t,
×
2868
        }
×
2869
}
×
2870

2871
func (c *Cast) isConstant() bool {
7✔
2872
        return c.val.isConstant()
7✔
2873
}
7✔
2874

2875
func (c *Cast) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2876
        return nil
×
2877
}
×
2878

2879
func (c *Cast) String() string {
1✔
2880
        return fmt.Sprintf("CAST (%s AS %s)", c.val.String(), c.t)
1✔
2881
}
1✔
2882

2883
type Param struct {
2884
        id  string
2885
        pos int
2886
}
2887

2888
func (v *Param) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
58✔
2889
        t, ok := params[v.id]
58✔
2890
        if !ok {
114✔
2891
                params[v.id] = AnyType
56✔
2892
                return AnyType, nil
56✔
2893
        }
56✔
2894

2895
        return t, nil
2✔
2896
}
2897

2898
func (v *Param) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
76✔
2899
        currT, ok := params[v.id]
76✔
2900
        if ok && currT != t && currT != AnyType {
80✔
2901
                return ErrInferredMultipleTypes
4✔
2902
        }
4✔
2903

2904
        params[v.id] = t
72✔
2905

72✔
2906
        return nil
72✔
2907
}
2908

2909
func (p *Param) substitute(params map[string]interface{}) (ValueExp, error) {
6,399✔
2910
        val, ok := params[p.id]
6,399✔
2911
        if !ok {
6,461✔
2912
                return nil, fmt.Errorf("%w(%s)", ErrMissingParameter, p.id)
62✔
2913
        }
62✔
2914

2915
        if val == nil {
6,386✔
2916
                return &NullValue{t: AnyType}, nil
49✔
2917
        }
49✔
2918

2919
        switch v := val.(type) {
6,288✔
2920
        case bool:
96✔
2921
                {
192✔
2922
                        return &Bool{val: v}, nil
96✔
2923
                }
96✔
2924
        case string:
1,752✔
2925
                {
3,504✔
2926
                        return &Varchar{val: v}, nil
1,752✔
2927
                }
1,752✔
2928
        case int:
1,678✔
2929
                {
3,356✔
2930
                        return &Integer{val: int64(v)}, nil
1,678✔
2931
                }
1,678✔
2932
        case uint:
×
2933
                {
×
2934
                        return &Integer{val: int64(v)}, nil
×
2935
                }
×
2936
        case uint64:
34✔
2937
                {
68✔
2938
                        return &Integer{val: int64(v)}, nil
34✔
2939
                }
34✔
2940
        case int64:
227✔
2941
                {
454✔
2942
                        return &Integer{val: v}, nil
227✔
2943
                }
227✔
2944
        case []byte:
14✔
2945
                {
28✔
2946
                        return &Blob{val: v}, nil
14✔
2947
                }
14✔
2948
        case time.Time:
861✔
2949
                {
1,722✔
2950
                        return &Timestamp{val: v.Truncate(time.Microsecond).UTC()}, nil
861✔
2951
                }
861✔
2952
        case float64:
1,625✔
2953
                {
3,250✔
2954
                        return &Float64{val: v}, nil
1,625✔
2955
                }
1,625✔
2956
        }
2957
        return nil, ErrUnsupportedParameter
1✔
2958
}
2959

2960
func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
×
2961
        return nil, ErrUnexpected
×
2962
}
×
2963

2964
func (p *Param) selectors() []Selector {
4✔
2965
        return nil
4✔
2966
}
4✔
2967

2968
func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2969
        return p
×
2970
}
×
2971

2972
func (p *Param) isConstant() bool {
130✔
2973
        return true
130✔
2974
}
130✔
2975

2976
func (v *Param) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
6✔
2977
        return nil
6✔
2978
}
6✔
2979

2980
func (v *Param) String() string {
2✔
2981
        return "@" + v.id
2✔
2982
}
2✔
2983

2984
type whenThenClause struct {
2985
        when, then ValueExp
2986
}
2987

2988
type CaseWhenExp struct {
2989
        exp      ValueExp
2990
        whenThen []whenThenClause
2991
        elseExp  ValueExp
2992
}
2993

2994
func (ce *CaseWhenExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
8✔
2995
        checkType := func(e ValueExp, expectedType SQLValueType) (string, error) {
18✔
2996
                t, err := e.inferType(cols, params, implicitTable)
10✔
2997
                if err != nil {
10✔
2998
                        return "", err
×
2999
                }
×
3000

3001
                if expectedType == AnyType {
15✔
3002
                        return t, nil
5✔
3003
                }
5✔
3004

3005
                if t != expectedType {
6✔
3006
                        if (t == Float64Type && expectedType == IntegerType) ||
1✔
3007
                                (t == IntegerType && expectedType == Float64Type) {
1✔
3008
                                return Float64Type, nil
×
3009
                        }
×
3010
                        return "", fmt.Errorf("%w: CASE types %s and %s cannot be matched", ErrInferredMultipleTypes, expectedType, t)
1✔
3011
                }
3012
                return t, nil
4✔
3013
        }
3014

3015
        searchType := BooleanType
8✔
3016
        inferredResType := AnyType
8✔
3017
        if ce.exp != nil {
11✔
3018
                t, err := ce.exp.inferType(cols, params, implicitTable)
3✔
3019
                if err != nil {
3✔
3020
                        return "", err
×
3021
                }
×
3022
                searchType = t
3✔
3023
        }
3024

3025
        for _, e := range ce.whenThen {
16✔
3026
                whenType, err := e.when.inferType(cols, params, implicitTable)
8✔
3027
                if err != nil {
8✔
3028
                        return "", err
×
3029
                }
×
3030

3031
                if whenType != searchType {
11✔
3032
                        return "", fmt.Errorf("%w: argument of CASE/WHEN must be of type %s, not type %s", ErrInvalidTypes, searchType, whenType)
3✔
3033
                }
3✔
3034

3035
                t, err := checkType(e.then, inferredResType)
5✔
3036
                if err != nil {
5✔
3037
                        return "", err
×
3038
                }
×
3039
                inferredResType = t
5✔
3040
        }
3041

3042
        if ce.elseExp != nil {
10✔
3043
                return checkType(ce.elseExp, inferredResType)
5✔
3044
        }
5✔
3045
        return inferredResType, nil
×
3046
}
3047

3048
func (ce *CaseWhenExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
3049
        inferredType, err := ce.inferType(cols, params, implicitTable)
2✔
3050
        if err != nil {
3✔
3051
                return err
1✔
3052
        }
1✔
3053

3054
        if inferredType != t {
1✔
3055
                return fmt.Errorf("%w: expected type %s but %s found instead", ErrInvalidTypes, t, inferredType)
×
3056
        }
×
3057
        return nil
1✔
3058
}
3059

3060
func (ce *CaseWhenExp) substitute(params map[string]interface{}) (ValueExp, error) {
504✔
3061
        var exp ValueExp
504✔
3062
        if ce.exp != nil {
605✔
3063
                e, err := ce.exp.substitute(params)
101✔
3064
                if err != nil {
101✔
3065
                        return nil, err
×
3066
                }
×
3067
                exp = e
101✔
3068
        }
3069

3070
        whenThen := make([]whenThenClause, len(ce.whenThen))
504✔
3071
        for i, wt := range ce.whenThen {
1,208✔
3072
                whenValue, err := wt.when.substitute(params)
704✔
3073
                if err != nil {
704✔
3074
                        return nil, err
×
3075
                }
×
3076
                whenThen[i].when = whenValue
704✔
3077

704✔
3078
                thenValue, err := wt.then.substitute(params)
704✔
3079
                if err != nil {
704✔
3080
                        return nil, err
×
3081
                }
×
3082
                whenThen[i].then = thenValue
704✔
3083
        }
3084

3085
        if ce.elseExp == nil {
506✔
3086
                return &CaseWhenExp{
2✔
3087
                        exp:      exp,
2✔
3088
                        whenThen: whenThen,
2✔
3089
                }, nil
2✔
3090
        }
2✔
3091

3092
        elseValue, err := ce.elseExp.substitute(params)
502✔
3093
        if err != nil {
502✔
3094
                return nil, err
×
3095
        }
×
3096
        return &CaseWhenExp{
502✔
3097
                exp:      exp,
502✔
3098
                whenThen: whenThen,
502✔
3099
                elseExp:  elseValue,
502✔
3100
        }, nil
502✔
3101
}
3102

3103
func (ce *CaseWhenExp) selectors() []Selector {
7✔
3104
        selectors := make([]Selector, 0)
7✔
3105
        if ce.exp != nil {
8✔
3106
                selectors = append(selectors, ce.exp.selectors()...)
1✔
3107
        }
1✔
3108

3109
        for _, wh := range ce.whenThen {
16✔
3110
                selectors = append(selectors, wh.when.selectors()...)
9✔
3111
                selectors = append(selectors, wh.then.selectors()...)
9✔
3112
        }
9✔
3113

3114
        if ce.elseExp == nil {
9✔
3115
                return selectors
2✔
3116
        }
2✔
3117
        return append(selectors, ce.elseExp.selectors()...)
5✔
3118
}
3119

3120
func (ce *CaseWhenExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
304✔
3121
        var searchValue TypedValue
304✔
3122
        if ce.exp != nil {
405✔
3123
                v, err := ce.exp.reduce(tx, row, implicitTable)
101✔
3124
                if err != nil {
101✔
3125
                        return nil, err
×
3126
                }
×
3127
                searchValue = v
101✔
3128
        } else {
203✔
3129
                searchValue = &Bool{val: true}
203✔
3130
        }
203✔
3131

3132
        for _, wt := range ce.whenThen {
734✔
3133
                v, err := wt.when.reduce(tx, row, implicitTable)
430✔
3134
                if err != nil {
430✔
3135
                        return nil, err
×
3136
                }
×
3137

3138
                if v.Type() != searchValue.Type() {
431✔
3139
                        return nil, fmt.Errorf("%w: argument of CASE/WHEN must be type %s, not type %s", ErrInvalidTypes, v.Type(), searchValue.Type())
1✔
3140
                }
1✔
3141

3142
                res, err := v.Compare(searchValue)
429✔
3143
                if err != nil {
429✔
3144
                        return nil, err
×
3145
                }
×
3146
                if res == 0 {
629✔
3147
                        return wt.then.reduce(tx, row, implicitTable)
200✔
3148
                }
200✔
3149
        }
3150

3151
        if ce.elseExp == nil {
104✔
3152
                return NewNull(AnyType), nil
1✔
3153
        }
1✔
3154
        return ce.elseExp.reduce(tx, row, implicitTable)
102✔
3155
}
3156

3157
func (ce *CaseWhenExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3158
        whenThen := make([]whenThenClause, len(ce.whenThen))
1✔
3159
        for i, wt := range ce.whenThen {
2✔
3160
                whenValue := wt.when.reduceSelectors(row, implicitTable)
1✔
3161
                whenThen[i].when = whenValue
1✔
3162

1✔
3163
                thenValue := wt.then.reduceSelectors(row, implicitTable)
1✔
3164
                whenThen[i].then = thenValue
1✔
3165
        }
1✔
3166

3167
        if ce.elseExp == nil {
1✔
3168
                return &CaseWhenExp{
×
3169
                        whenThen: whenThen,
×
3170
                }
×
3171
        }
×
3172

3173
        return &CaseWhenExp{
1✔
3174
                whenThen: whenThen,
1✔
3175
                elseExp:  ce.elseExp.reduceSelectors(row, implicitTable),
1✔
3176
        }
1✔
3177
}
3178

3179
func (ce *CaseWhenExp) isConstant() bool {
1✔
3180
        return false
1✔
3181
}
1✔
3182

3183
func (ce *CaseWhenExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
3184
        return nil
1✔
3185
}
1✔
3186

3187
func (ce *CaseWhenExp) String() string {
3✔
3188
        var sb strings.Builder
3✔
3189
        for _, wh := range ce.whenThen {
7✔
3190
                sb.WriteString(fmt.Sprintf("WHEN %s THEN %s ", wh.when.String(), wh.then.String()))
4✔
3191
        }
4✔
3192

3193
        if ce.elseExp != nil {
5✔
3194
                sb.WriteString("ELSE " + ce.elseExp.String() + " ")
2✔
3195
        }
2✔
3196
        return "CASE " + sb.String() + "END"
3✔
3197
}
3198

3199
// Comparison identifies an ordering comparison operator.
//
// The constants are assigned with iota, so their declaration order below is
// significant and must not be changed.
type Comparison int

const (
	EqualTo Comparison = iota
	LowerThan
	LowerOrEqualTo
	GreaterThan
	GreaterOrEqualTo
)
3208

3209
// DataSource is a statement that produces rows for a query. SelectStmt reads
// its base rows from its ds field through Resolve.
type DataSource interface {
	SQLStmt
	// Resolve builds a RowReader over the source's rows. SelectStmt passes the
	// scan specs computed by its genScanSpecs.
	Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error)
	// Alias returns the name used to resolve selectors against this source.
	// NOTE(review): presumably the table name or its AS alias — confirm.
	Alias() string
}
3214

3215
// TargetEntry is one entry of a SELECT target list.
type TargetEntry struct {
	// Exp is the projected expression.
	Exp ValueExp
	// As is the output name given to the expression.
	// NOTE(review): presumably empty when no AS alias was written — confirm.
	As string
}
3219

3220
// SelectStmt is a parsed SELECT statement. Resolve turns its clauses into a
// chain of row readers layered over ds.
type SelectStmt struct {
	// distinct enables SELECT DISTINCT (a distinct row reader over the projection).
	distinct bool
	// targets is the projected target list.
	targets []TargetEntry
	// selectors caches the selectors of targets; filled lazily by targetSelectors.
	selectors []Selector
	// ds is the FROM source the rows are read from.
	ds DataSource
	// NOTE(review): indexOn is not read in this section; presumably the
	// columns of a USE INDEX ON hint — confirm.
	indexOn []string
	joins   []*JoinSpec
	where   ValueExp
	groupBy []*ColSelector
	// having requires a GROUP BY clause (enforced by execAt).
	having  ValueExp
	orderBy []*OrdExp
	// limit and offset are evaluated to ints by Resolve; a limit of 0 means no limit.
	limit  ValueExp
	offset ValueExp
	// as is the alias handed to the projected row reader.
	as string
}
3235

3236
func NewSelectStmt(
3237
        targets []TargetEntry,
3238
        ds DataSource,
3239
        where ValueExp,
3240
        orderBy []*OrdExp,
3241
        limit ValueExp,
3242
        offset ValueExp,
3243
) *SelectStmt {
71✔
3244
        return &SelectStmt{
71✔
3245
                targets: targets,
71✔
3246
                ds:      ds,
71✔
3247
                where:   where,
71✔
3248
                orderBy: orderBy,
71✔
3249
                limit:   limit,
71✔
3250
                offset:  offset,
71✔
3251
        }
71✔
3252
}
71✔
3253

3254
func (stmt *SelectStmt) readOnly() bool {
94✔
3255
        return true
94✔
3256
}
94✔
3257

3258
func (stmt *SelectStmt) requiredPrivileges() []SQLPrivilege {
96✔
3259
        return []SQLPrivilege{SQLPrivilegeSelect}
96✔
3260
}
96✔
3261

3262
func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
53✔
3263
        _, err := stmt.execAt(ctx, tx, nil)
53✔
3264
        if err != nil {
53✔
3265
                return err
×
3266
        }
×
3267

3268
        // TODO: (jeroiraz) may be optimized so to resolve the query statement just once
3269
        rowReader, err := stmt.Resolve(ctx, tx, nil, nil)
53✔
3270
        if err != nil {
54✔
3271
                return err
1✔
3272
        }
1✔
3273
        defer rowReader.Close()
52✔
3274

52✔
3275
        return rowReader.InferParameters(ctx, params)
52✔
3276
}
3277

3278
func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
621✔
3279
        if stmt.groupBy == nil && stmt.having != nil {
622✔
3280
                return nil, ErrHavingClauseRequiresGroupClause
1✔
3281
        }
1✔
3282

3283
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
705✔
3284
                for _, sel := range stmt.targetSelectors() {
227✔
3285
                        _, isAgg := sel.(*AggColSelector)
142✔
3286
                        if !isAgg && !stmt.groupByContains(sel) {
144✔
3287
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3288
                        }
2✔
3289
                }
3290
        }
3291

3292
        if len(stmt.orderBy) > 0 {
775✔
3293
                for _, col := range stmt.orderBy {
352✔
3294
                        for _, sel := range col.exp.selectors() {
380✔
3295
                                _, isAgg := sel.(*AggColSelector)
185✔
3296
                                if (isAgg && !stmt.selectorAppearsInTargets(sel)) || (!isAgg && len(stmt.groupBy) > 0 && !stmt.groupByContains(sel)) {
187✔
3297
                                        return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3298
                                }
2✔
3299
                        }
3300
                }
3301
        }
3302
        return tx, nil
616✔
3303
}
3304

3305
func (stmt *SelectStmt) targetSelectors() []Selector {
2,524✔
3306
        if stmt.selectors == nil {
3,476✔
3307
                stmt.selectors = stmt.extractSelectors()
952✔
3308
        }
952✔
3309
        return stmt.selectors
2,524✔
3310
}
3311

3312
func (stmt *SelectStmt) selectorAppearsInTargets(s Selector) bool {
4✔
3313
        encSel := EncodeSelector(s.resolve(stmt.Alias()))
4✔
3314

4✔
3315
        for _, sel := range stmt.targetSelectors() {
12✔
3316
                if EncodeSelector(sel.resolve(stmt.Alias())) == encSel {
11✔
3317
                        return true
3✔
3318
                }
3✔
3319
        }
3320
        return false
1✔
3321
}
3322

3323
func (stmt *SelectStmt) groupByContains(sel Selector) bool {
57✔
3324
        encSel := EncodeSelector(sel.resolve(stmt.Alias()))
57✔
3325

57✔
3326
        for _, colSel := range stmt.groupBy {
137✔
3327
                if EncodeSelector(colSel.resolve(stmt.Alias())) == encSel {
134✔
3328
                        return true
54✔
3329
                }
54✔
3330
        }
3331
        return false
3✔
3332
}
3333

3334
func (stmt *SelectStmt) extractGroupByCols() []*AggColSelector {
80✔
3335
        cols := make([]*AggColSelector, 0, len(stmt.targets))
80✔
3336

80✔
3337
        for _, t := range stmt.targets {
214✔
3338
                selectors := t.Exp.selectors()
134✔
3339
                for _, sel := range selectors {
268✔
3340
                        aggSel, isAgg := sel.(*AggColSelector)
134✔
3341
                        if isAgg {
244✔
3342
                                cols = append(cols, aggSel)
110✔
3343
                        }
110✔
3344
                }
3345
        }
3346
        return cols
80✔
3347
}
3348

3349
func (stmt *SelectStmt) extractSelectors() []Selector {
952✔
3350
        selectors := make([]Selector, 0, len(stmt.targets))
952✔
3351
        for _, t := range stmt.targets {
1,835✔
3352
                selectors = append(selectors, t.Exp.selectors()...)
883✔
3353
        }
883✔
3354
        return selectors
952✔
3355
}
3356

3357
func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
959✔
3358
        scanSpecs, err := stmt.genScanSpecs(tx, params)
959✔
3359
        if err != nil {
975✔
3360
                return nil, err
16✔
3361
        }
16✔
3362

3363
        rowReader, err := stmt.ds.Resolve(ctx, tx, params, scanSpecs)
943✔
3364
        if err != nil {
946✔
3365
                return nil, err
3✔
3366
        }
3✔
3367
        defer func() {
1,880✔
3368
                if err != nil {
947✔
3369
                        rowReader.Close()
7✔
3370
                }
7✔
3371
        }()
3372

3373
        if stmt.joins != nil {
955✔
3374
                var jointRowReader *jointRowReader
15✔
3375
                jointRowReader, err = newJointRowReader(rowReader, stmt.joins)
15✔
3376
                if err != nil {
16✔
3377
                        return nil, err
1✔
3378
                }
1✔
3379
                rowReader = jointRowReader
14✔
3380
        }
3381

3382
        if stmt.where != nil {
1,468✔
3383
                rowReader = newConditionalRowReader(rowReader, stmt.where)
529✔
3384
        }
529✔
3385

3386
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
1,019✔
3387
                if len(scanSpecs.groupBySortExps) > 0 {
92✔
3388
                        var sortRowReader *sortRowReader
12✔
3389
                        sortRowReader, err = newSortRowReader(rowReader, scanSpecs.groupBySortExps)
12✔
3390
                        if err != nil {
12✔
3391
                                return nil, err
×
3392
                        }
×
3393
                        rowReader = sortRowReader
12✔
3394
                }
3395

3396
                var groupedRowReader *groupedRowReader
80✔
3397
                groupedRowReader, err = newGroupedRowReader(rowReader, allAggregations(stmt.targets), stmt.extractGroupByCols(), stmt.groupBy)
80✔
3398
                if err != nil {
82✔
3399
                        return nil, err
2✔
3400
                }
2✔
3401
                rowReader = groupedRowReader
78✔
3402

78✔
3403
                if stmt.having != nil {
82✔
3404
                        rowReader = newConditionalRowReader(rowReader, stmt.having)
4✔
3405
                }
4✔
3406
        }
3407

3408
        if len(scanSpecs.orderBySortExps) > 0 {
986✔
3409
                var sortRowReader *sortRowReader
49✔
3410
                sortRowReader, err = newSortRowReader(rowReader, stmt.orderBy)
49✔
3411
                if err != nil {
50✔
3412
                        return nil, err
1✔
3413
                }
1✔
3414
                rowReader = sortRowReader
48✔
3415
        }
3416

3417
        projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.targets)
936✔
3418
        if err != nil {
939✔
3419
                return nil, err
3✔
3420
        }
3✔
3421
        rowReader = projectedRowReader
933✔
3422

933✔
3423
        if stmt.distinct {
940✔
3424
                var distinctRowReader *distinctRowReader
7✔
3425
                distinctRowReader, err = newDistinctRowReader(ctx, rowReader)
7✔
3426
                if err != nil {
7✔
3427
                        return nil, err
×
3428
                }
×
3429
                rowReader = distinctRowReader
7✔
3430
        }
3431

3432
        if stmt.offset != nil {
983✔
3433
                var offset int
50✔
3434
                offset, err = evalExpAsInt(tx, stmt.offset, params)
50✔
3435
                if err != nil {
50✔
3436
                        return nil, fmt.Errorf("%w: invalid offset", err)
×
3437
                }
×
3438

3439
                rowReader = newOffsetRowReader(rowReader, offset)
50✔
3440
        }
3441

3442
        if stmt.limit != nil {
1,030✔
3443
                var limit int
97✔
3444
                limit, err = evalExpAsInt(tx, stmt.limit, params)
97✔
3445
                if err != nil {
97✔
3446
                        return nil, fmt.Errorf("%w: invalid limit", err)
×
3447
                }
×
3448

3449
                if limit < 0 {
97✔
3450
                        return nil, fmt.Errorf("%w: invalid limit", ErrIllegalArguments)
×
3451
                }
×
3452

3453
                if limit > 0 {
140✔
3454
                        rowReader = newLimitRowReader(rowReader, limit)
43✔
3455
                }
43✔
3456
        }
3457
        return rowReader, nil
933✔
3458
}
3459

3460
func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
942✔
3461
        if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
948✔
3462
                if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {
8✔
3463
                        return orderByExps, nil
2✔
3464
                }
2✔
3465

3466
                if ordExpsHasPrefix(groupByCols, orderByExps, stmt.Alias()) {
5✔
3467
                        for i := range orderByExps {
2✔
3468
                                groupByCols[i].descOrder = orderByExps[i].descOrder
1✔
3469
                        }
1✔
3470
                        return groupByCols, nil
1✔
3471
                }
3472
        }
3473
        return groupByCols, orderByExps
939✔
3474
}
3475

3476
func ordExpsHasPrefix(cols, prefix []*OrdExp, table string) bool {
10✔
3477
        if len(prefix) > len(cols) {
12✔
3478
                return false
2✔
3479
        }
2✔
3480

3481
        for i := range prefix {
17✔
3482
                ls := prefix[i].AsSelector()
9✔
3483
                rs := cols[i].AsSelector()
9✔
3484

9✔
3485
                if ls == nil || rs == nil {
9✔
3486
                        return false
×
3487
                }
×
3488

3489
                if EncodeSelector(ls.resolve(table)) != EncodeSelector(rs.resolve(table)) {
14✔
3490
                        return false
5✔
3491
                }
5✔
3492
        }
3493
        return true
3✔
3494
}
3495

3496
func (stmt *SelectStmt) groupByOrdExps() []*OrdExp {
959✔
3497
        groupByCols := stmt.groupBy
959✔
3498

959✔
3499
        ordExps := make([]*OrdExp, 0, len(groupByCols))
959✔
3500
        for _, col := range groupByCols {
1,006✔
3501
                ordExps = append(ordExps, &OrdExp{exp: col})
47✔
3502
        }
47✔
3503
        return ordExps
959✔
3504
}
3505

3506
func ordExpsHaveAggregations(exps []*OrdExp) bool {
7✔
3507
        for _, e := range exps {
17✔
3508
                if _, isAgg := e.exp.(*AggColSelector); isAgg {
11✔
3509
                        return true
1✔
3510
                }
1✔
3511
        }
3512
        return false
6✔
3513
}
3514

3515
func (stmt *SelectStmt) containsAggregations() bool {
1,559✔
3516
        for _, sel := range stmt.targetSelectors() {
3,155✔
3517
                _, isAgg := sel.(*AggColSelector)
1,596✔
3518
                if isAgg {
1,759✔
3519
                        return true
163✔
3520
                }
163✔
3521
        }
3522
        return false
1,396✔
3523
}
3524

3525
func evalExpAsInt(tx *SQLTx, exp ValueExp, params map[string]interface{}) (int, error) {
147✔
3526
        offset, err := exp.substitute(params)
147✔
3527
        if err != nil {
147✔
3528
                return 0, err
×
3529
        }
×
3530

3531
        texp, err := offset.reduce(tx, nil, "")
147✔
3532
        if err != nil {
147✔
3533
                return 0, err
×
3534
        }
×
3535

3536
        convVal, err := mayApplyImplicitConversion(texp.RawValue(), IntegerType)
147✔
3537
        if err != nil {
147✔
3538
                return 0, ErrInvalidValue
×
3539
        }
×
3540

3541
        num, ok := convVal.(int64)
147✔
3542
        if !ok {
147✔
3543
                return 0, ErrInvalidValue
×
3544
        }
×
3545

3546
        if num > math.MaxInt32 {
147✔
3547
                return 0, ErrInvalidValue
×
3548
        }
×
3549

3550
        return int(num), nil
147✔
3551
}
3552

3553
func (stmt *SelectStmt) Alias() string {
167✔
3554
        if stmt.as == "" {
333✔
3555
                return stmt.ds.Alias()
166✔
3556
        }
166✔
3557

3558
        return stmt.as
1✔
3559
}
3560

3561
func (stmt *SelectStmt) hasTxMetadata() bool {
876✔
3562
        for _, sel := range stmt.targetSelectors() {
1,692✔
3563
                switch s := sel.(type) {
816✔
3564
                case *ColSelector:
702✔
3565
                        if s.col == txMetadataCol {
703✔
3566
                                return true
1✔
3567
                        }
1✔
3568
                case *JSONSelector:
21✔
3569
                        if s.ColSelector.col == txMetadataCol {
24✔
3570
                                return true
3✔
3571
                        }
3✔
3572
                }
3573
        }
3574
        return false
872✔
3575
}
3576

3577
// genScanSpecs works out how the statement's data source should be scanned:
//   - which index to read through;
//   - which per-column value ranges can be derived from the WHERE clause;
//   - the scan direction;
//   - which GROUP BY / ORDER BY sorts still have to run after the scan.
//
// Only the sort expressions are returned when the source is not a plain table
// reference, or when it names a table the catalog does not know but a
// registered table resolver handles.
func (stmt *SelectStmt) genScanSpecs(tx *SQLTx, params map[string]interface{}) (*ScanSpecs, error) {
	groupByCols, orderByCols := stmt.groupByOrdExps(), stmt.orderBy

	tableRef, isTableRef := stmt.ds.(*tableRef)
	if !isTableRef {
		// Without a table there is no index to exploit; merge the sorts where
		// possible and leave them to be performed after the scan.
		groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)

		return &ScanSpecs{
			groupBySortExps: groupByCols,
			orderBySortExps: orderByCols,
		}, nil
	}

	table, err := tableRef.referencedTable(tx)
	if err != nil {
		// Not in the catalog: if a table resolver is registered for this name,
		// return only the sort expressions instead of failing. Note that,
		// unlike the branch above, rearrangeOrdExps is not applied here.
		if tx.engine.tableResolveFor(tableRef.table) != nil {
			return &ScanSpecs{
				groupBySortExps: groupByCols,
				orderBySortExps: orderByCols,
			}, nil
		}
		return nil, err
	}

	// Collect the per-column value ranges the WHERE clause can provide.
	rangesByColID := make(map[uint32]*typedValueRange)
	if stmt.where != nil {
		err = stmt.where.selectorRanges(table, tableRef.Alias(), params, rangesByColID)
		if err != nil {
			return nil, err
		}
	}

	// An index explicitly requested by the statement (stmt.indexOn) takes
	// precedence over automatic selection.
	preferredIndex, err := stmt.getPreferredIndex(table)
	if err != nil {
		return nil, err
	}

	var sortingIndex *Index
	if preferredIndex == nil {
		sortingIndex = stmt.selectSortingIndex(groupByCols, orderByCols, table, rangesByColID)
	} else {
		sortingIndex = preferredIndex
	}

	// Fall back to the primary index when no index was chosen.
	if sortingIndex == nil {
		sortingIndex = table.primaryIndex
	}

	if tableRef.history && !sortingIndex.IsPrimary() {
		return nil, fmt.Errorf("%w: historical queries are supported over primary index", ErrIllegalArguments)
	}

	// When the chosen index covers the GROUP BY columns (coversOrdCols), no
	// separate grouping sort is kept.
	var descOrder bool
	if len(groupByCols) > 0 && sortingIndex.coversOrdCols(groupByCols, rangesByColID) {
		groupByCols = nil
	}

	// Likewise for ORDER BY, but only once no grouping sort remains. The scan
	// direction then follows the first ORDER BY expression.
	if len(groupByCols) == 0 && len(orderByCols) > 0 && sortingIndex.coversOrdCols(orderByCols, rangesByColID) {
		descOrder = orderByCols[0].descOrder
		orderByCols = nil
	}

	// Merge whatever sorts are still required, when possible.
	groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)

	return &ScanSpecs{
		Index:             sortingIndex,
		rangesByColID:     rangesByColID,
		IncludeHistory:    tableRef.history,
		IncludeTxMetadata: stmt.hasTxMetadata(),
		DescOrder:         descOrder,
		groupBySortExps:   groupByCols,
		orderBySortExps:   orderByCols,
	}, nil
}
3651

3652
func (stmt *SelectStmt) selectSortingIndex(groupByCols, orderByCols []*OrdExp, table *Table, rangesByColId map[uint32]*typedValueRange) *Index {
846✔
3653
        sortCols := groupByCols
846✔
3654
        if len(sortCols) == 0 {
1,666✔
3655
                sortCols = orderByCols
820✔
3656
        }
820✔
3657

3658
        if len(sortCols) == 0 {
1,553✔
3659
                return nil
707✔
3660
        }
707✔
3661

3662
        for _, idx := range table.indexes {
364✔
3663
                if idx.coversOrdCols(sortCols, rangesByColId) {
318✔
3664
                        return idx
93✔
3665
                }
93✔
3666
        }
3667
        return nil
46✔
3668
}
3669

3670
func (stmt *SelectStmt) getPreferredIndex(table *Table) (*Index, error) {
876✔
3671
        if len(stmt.indexOn) == 0 {
1,722✔
3672
                return nil, nil
846✔
3673
        }
846✔
3674

3675
        cols := make([]*Column, len(stmt.indexOn))
30✔
3676
        for i, colName := range stmt.indexOn {
80✔
3677
                col, err := table.GetColumnByName(colName)
50✔
3678
                if err != nil {
50✔
3679
                        return nil, err
×
3680
                }
×
3681

3682
                cols[i] = col
50✔
3683
        }
3684
        return table.GetIndexByName(indexName(table.name, cols))
30✔
3685
}
3686

3687
type UnionStmt struct {
3688
        distinct    bool
3689
        left, right DataSource
3690
}
3691

3692
func (stmt *UnionStmt) readOnly() bool {
1✔
3693
        return true
1✔
3694
}
1✔
3695

3696
func (stmt *UnionStmt) requiredPrivileges() []SQLPrivilege {
1✔
3697
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3698
}
1✔
3699

3700
func (stmt *UnionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3701
        err := stmt.left.inferParameters(ctx, tx, params)
1✔
3702
        if err != nil {
1✔
3703
                return err
×
3704
        }
×
3705
        return stmt.right.inferParameters(ctx, tx, params)
1✔
3706
}
3707

3708
func (stmt *UnionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
3709
        _, err := stmt.left.execAt(ctx, tx, params)
9✔
3710
        if err != nil {
9✔
3711
                return tx, err
×
3712
        }
×
3713

3714
        return stmt.right.execAt(ctx, tx, params)
9✔
3715
}
3716

3717
// resolveUnionAll resolves both sides of the union and concatenates their rows
// through a union row reader, without removing duplicates.
//
// Cleanup relies on the named result err. Each deferred closure closes an
// already-opened reader if the function ends with a non-nil err, so on failure
// the caller receives no reader and has nothing to close. On success both
// readers are handed to newUnionRowReader; presumably the union reader then
// owns them (TODO confirm in newUnionRowReader).
func (stmt *UnionStmt) resolveUnionAll(ctx context.Context, tx *SQLTx, params map[string]interface{}) (ret RowReader, err error) {
	// err here is the named result (same scope), not a new variable.
	leftRowReader, err := stmt.left.Resolve(ctx, tx, params, nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			leftRowReader.Close()
		}
	}()

	rightRowReader, err := stmt.right.Resolve(ctx, tx, params, nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			rightRowReader.Close()
		}
	}()

	rowReader, err := newUnionRowReader(ctx, []RowReader{leftRowReader, rightRowReader})
	if err != nil {
		return nil, err
	}

	return rowReader, nil
}
3745

3746
func (stmt *UnionStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
11✔
3747
        rowReader, err := stmt.resolveUnionAll(ctx, tx, params)
11✔
3748
        if err != nil {
16✔
3749
                return nil, err
5✔
3750
        }
5✔
3751
        defer func() {
12✔
3752
                if err != nil {
7✔
3753
                        rowReader.Close()
1✔
3754
                }
1✔
3755
        }()
3756

3757
        if stmt.distinct {
11✔
3758
                distinctReader, err := newDistinctRowReader(ctx, rowReader)
5✔
3759
                if err != nil {
6✔
3760
                        return nil, err
1✔
3761
                }
1✔
3762
                rowReader = distinctReader
4✔
3763
        }
3764

3765
        return rowReader, nil
5✔
3766
}
3767

3768
func (stmt *UnionStmt) Alias() string {
×
3769
        return ""
×
3770
}
×
3771

3772
func NewTableRef(table string, as string) *tableRef {
179✔
3773
        return &tableRef{
179✔
3774
                table: table,
179✔
3775
                as:    as,
179✔
3776
        }
179✔
3777
}
179✔
3778

3779
type tableRef struct {
3780
        table   string
3781
        history bool
3782
        period  period
3783
        as      string
3784
}
3785

3786
func (ref *tableRef) readOnly() bool {
1✔
3787
        return true
1✔
3788
}
1✔
3789

3790
func (ref *tableRef) requiredPrivileges() []SQLPrivilege {
1✔
3791
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3792
}
1✔
3793

3794
// period bounds a table scan by a start and an end. Either pointer may be nil;
// presumably a nil end means that side is unbounded — TODO confirm at the
// call sites.
type period struct {
	start *openPeriod
	end   *openPeriod
}

// openPeriod is one end of a period.
type openPeriod struct {
	inclusive bool // whether the bound includes the instant itself
	instant   periodInstant
}

// periodInstant is an expression that identifies a point in the transaction
// log, either as a transaction ID or as a timestamp (see instantType).
type periodInstant struct {
	exp         ValueExp
	instantType instantType
}

// instantType selects how a periodInstant's expression is interpreted.
type instantType = int

const (
	txInstant   instantType = iota // expression evaluates to a transaction ID
	timeInstant                    // expression evaluates to a timestamp (or a value convertible to one)
)
3815

3816
func (i periodInstant) resolve(tx *SQLTx, params map[string]interface{}, asc, inclusive bool) (uint64, error) {
81✔
3817
        exp, err := i.exp.substitute(params)
81✔
3818
        if err != nil {
81✔
3819
                return 0, err
×
3820
        }
×
3821

3822
        instantVal, err := exp.reduce(tx, nil, "")
81✔
3823
        if err != nil {
83✔
3824
                return 0, err
2✔
3825
        }
2✔
3826

3827
        if i.instantType == txInstant {
124✔
3828
                txID, ok := instantVal.RawValue().(int64)
45✔
3829
                if !ok {
45✔
3830
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %s given", ErrIllegalArguments, instantVal.Type())
×
3831
                }
×
3832

3833
                if txID <= 0 {
52✔
3834
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %d given", ErrIllegalArguments, txID)
7✔
3835
                }
7✔
3836

3837
                if inclusive {
61✔
3838
                        return uint64(txID), nil
23✔
3839
                }
23✔
3840

3841
                if asc {
26✔
3842
                        return uint64(txID + 1), nil
11✔
3843
                }
11✔
3844

3845
                if txID <= 1 {
5✔
3846
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be greater than 1, %d given", ErrIllegalArguments, txID)
1✔
3847
                }
1✔
3848

3849
                return uint64(txID - 1), nil
3✔
3850
        } else {
34✔
3851

34✔
3852
                var ts time.Time
34✔
3853

34✔
3854
                if instantVal.Type() == TimestampType {
67✔
3855
                        ts = instantVal.RawValue().(time.Time)
33✔
3856
                } else {
34✔
3857
                        conv, err := getConverter(instantVal.Type(), TimestampType)
1✔
3858
                        if err != nil {
1✔
3859
                                return 0, err
×
3860
                        }
×
3861

3862
                        tval, err := conv(instantVal)
1✔
3863
                        if err != nil {
1✔
3864
                                return 0, err
×
3865
                        }
×
3866

3867
                        ts = tval.RawValue().(time.Time)
1✔
3868
                }
3869

3870
                sts := ts
34✔
3871

34✔
3872
                if asc {
57✔
3873
                        if !inclusive {
34✔
3874
                                sts = sts.Add(1 * time.Second)
11✔
3875
                        }
11✔
3876

3877
                        txHdr, err := tx.engine.store.FirstTxSince(sts)
23✔
3878
                        if err != nil {
34✔
3879
                                return 0, err
11✔
3880
                        }
11✔
3881

3882
                        return txHdr.ID, nil
12✔
3883
                }
3884

3885
                if !inclusive {
11✔
3886
                        sts = sts.Add(-1 * time.Second)
×
3887
                }
×
3888

3889
                txHdr, err := tx.engine.store.LastTxUntil(sts)
11✔
3890
                if err != nil {
11✔
3891
                        return 0, err
×
3892
                }
×
3893

3894
                return txHdr.ID, nil
11✔
3895
        }
3896
}
3897

3898
func (stmt *tableRef) referencedTable(tx *SQLTx) (*Table, error) {
3,943✔
3899
        table, err := tx.catalog.GetTableByName(stmt.table)
3,943✔
3900
        if err != nil {
3,963✔
3901
                return nil, err
20✔
3902
        }
20✔
3903
        return table, nil
3,923✔
3904
}
3905

3906
func (stmt *tableRef) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3907
        return nil
1✔
3908
}
1✔
3909

3910
func (stmt *tableRef) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3911
        return tx, nil
×
3912
}
×
3913

3914
func (stmt *tableRef) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
906✔
3915
        if tx == nil {
906✔
3916
                return nil, ErrIllegalArguments
×
3917
        }
×
3918

3919
        table, err := stmt.referencedTable(tx)
906✔
3920
        if err == nil {
1,811✔
3921
                return newRawRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
905✔
3922
        }
905✔
3923

3924
        if resolver := tx.engine.tableResolveFor(stmt.table); resolver != nil {
2✔
3925
                return resolver.Resolve(ctx, tx, stmt.Alias())
1✔
3926
        }
1✔
3927
        return nil, err
×
3928
}
3929

3930
func (stmt *tableRef) Alias() string {
683✔
3931
        if stmt.as == "" {
1,206✔
3932
                return stmt.table
523✔
3933
        }
523✔
3934
        return stmt.as
160✔
3935
}
3936

3937
type valuesDataSource struct {
3938
        inferTypes bool
3939
        rows       []*RowSpec
3940
}
3941

3942
func NewValuesDataSource(rows []*RowSpec) *valuesDataSource {
120✔
3943
        return &valuesDataSource{
120✔
3944
                rows: rows,
120✔
3945
        }
120✔
3946
}
120✔
3947

3948
func (ds *valuesDataSource) readOnly() bool {
×
3949
        return true
×
3950
}
×
3951

3952
func (ds *valuesDataSource) requiredPrivileges() []SQLPrivilege {
97✔
3953
        return nil
97✔
3954
}
97✔
3955

3956
func (ds *valuesDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3957
        return tx, nil
×
3958
}
×
3959

3960
func (ds *valuesDataSource) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
3961
        return nil
×
3962
}
×
3963

3964
func (ds *valuesDataSource) Alias() string {
×
3965
        return ""
×
3966
}
×
3967

3968
func (ds *valuesDataSource) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
2,123✔
3969
        if tx == nil {
2,123✔
3970
                return nil, ErrIllegalArguments
×
3971
        }
×
3972

3973
        cols := make([]ColDescriptor, len(ds.rows[0].Values))
2,123✔
3974
        for i := range cols {
9,775✔
3975
                cols[i] = ColDescriptor{
7,652✔
3976
                        Type:   AnyType,
7,652✔
3977
                        Column: fmt.Sprintf("col%d", i),
7,652✔
3978
                }
7,652✔
3979
        }
7,652✔
3980

3981
        emptyColsDesc, emptyParams := map[string]ColDescriptor{}, map[string]string{}
2,123✔
3982

2,123✔
3983
        if ds.inferTypes {
2,132✔
3984
                for i := 0; i < len(cols); i++ {
48✔
3985
                        t := AnyType
39✔
3986
                        for j := 0; j < len(ds.rows); j++ {
146✔
3987
                                e, err := ds.rows[j].Values[i].substitute(params)
107✔
3988
                                if err != nil {
107✔
3989
                                        return nil, err
×
3990
                                }
×
3991

3992
                                it, err := e.inferType(emptyColsDesc, emptyParams, "")
107✔
3993
                                if err != nil {
107✔
3994
                                        return nil, err
×
3995
                                }
×
3996

3997
                                if t == AnyType {
146✔
3998
                                        t = it
39✔
3999
                                } else if t != it && it != AnyType {
109✔
4000
                                        return nil, fmt.Errorf("cannot match types %s and %s", t, it)
2✔
4001
                                }
2✔
4002
                        }
4003
                        cols[i].Type = t
37✔
4004
                }
4005
        }
4006

4007
        values := make([][]ValueExp, len(ds.rows))
2,121✔
4008
        for i, rowSpec := range ds.rows {
4,351✔
4009
                values[i] = rowSpec.Values
2,230✔
4010
        }
2,230✔
4011
        return NewValuesRowReader(tx, params, cols, ds.inferTypes, "values", values)
2,121✔
4012
}
4013

4014
// JoinSpec describes one JOIN clause. It holds the join type, the data source
// being joined, and the join condition. indexOn presumably lists the columns
// of an index to use for the join, like SelectStmt.indexOn — TODO confirm.
type JoinSpec struct {
	joinType JoinType
	ds       DataSource
	cond     ValueExp
	indexOn  []string
}

// OrdExp is an expression used as a sort key, together with its direction.
// It is used both for ORDER BY and for GROUP BY (see groupByOrdExps).
type OrdExp struct {
	exp       ValueExp
	descOrder bool // true for descending order
}
4025

4026
func (oc *OrdExp) AsSelector() Selector {
714✔
4027
        sel, ok := oc.exp.(Selector)
714✔
4028
        if ok {
1,374✔
4029
                return sel
660✔
4030
        }
660✔
4031
        return nil
54✔
4032
}
4033

4034
func NewOrdCol(table string, col string, descOrder bool) *OrdExp {
1✔
4035
        return &OrdExp{
1✔
4036
                exp:       NewColSelector(table, col),
1✔
4037
                descOrder: descOrder,
1✔
4038
        }
1✔
4039
}
1✔
4040

4041
// Selector is a value expression that names a column, possibly wrapped in an
// aggregate function.
type Selector interface {
	ValueExp
	// resolve returns three values: the aggregate function name ("" for plain
	// columns), the table (implicitTable when the selector has no explicit
	// table), and the column name.
	resolve(implicitTable string) (aggFn, table, col string)
}
4045

4046
// ColSelector references a column by name, optionally qualified by a table
// name or alias.
type ColSelector struct {
	table string
	col   string
}

// NewColSelector returns a selector for column col, qualified by table when
// table is non-empty.
func NewColSelector(table, col string) *ColSelector {
	sel := new(ColSelector)
	sel.table, sel.col = table, col
	return sel
}

// resolve returns an empty aggregate name (plain columns are not
// aggregated), the selector's table (implicitTable when unqualified), and the
// column name.
func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) {
	if sel.table != "" {
		return "", sel.table, sel.col
	}
	return "", implicitTable, sel.col
}
4065

4066
func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
685✔
4067
        _, table, col := sel.resolve(implicitTable)
685✔
4068
        encSel := EncodeSelector("", table, col)
685✔
4069

685✔
4070
        desc, ok := cols[encSel]
685✔
4071
        if !ok {
688✔
4072
                return AnyType, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
3✔
4073
        }
3✔
4074
        return desc.Type, nil
682✔
4075
}
4076

4077
func (sel *ColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
4078
        _, table, col := sel.resolve(implicitTable)
15✔
4079
        encSel := EncodeSelector("", table, col)
15✔
4080

15✔
4081
        desc, ok := cols[encSel]
15✔
4082
        if !ok {
17✔
4083
                return fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
2✔
4084
        }
2✔
4085

4086
        if desc.Type != t {
16✔
4087
                return fmt.Errorf("%w: %v(%s) can not be interpreted as type %v", ErrInvalidTypes, desc.Type, encSel, t)
3✔
4088
        }
3✔
4089

4090
        return nil
10✔
4091
}
4092

4093
func (sel *ColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
161,866✔
4094
        return sel, nil
161,866✔
4095
}
161,866✔
4096

4097
func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
715,143✔
4098
        if row == nil {
715,144✔
4099
                return nil, fmt.Errorf("%w: no row to evaluate in current context", ErrInvalidValue)
1✔
4100
        }
1✔
4101

4102
        aggFn, table, col := sel.resolve(implicitTable)
715,142✔
4103

715,142✔
4104
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
715,142✔
4105
        if !ok {
715,149✔
4106
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
7✔
4107
        }
7✔
4108
        return v, nil
715,135✔
4109
}
4110

4111
func (sel *ColSelector) selectors() []Selector {
926✔
4112
        return []Selector{sel}
926✔
4113
}
926✔
4114

4115
func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
568✔
4116
        aggFn, table, col := sel.resolve(implicitTable)
568✔
4117

568✔
4118
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
568✔
4119
        if !ok {
846✔
4120
                return sel
278✔
4121
        }
278✔
4122

4123
        return v
290✔
4124
}
4125

4126
func (sel *ColSelector) isConstant() bool {
12✔
4127
        return false
12✔
4128
}
12✔
4129

4130
func (sel *ColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
11✔
4131
        return nil
11✔
4132
}
11✔
4133

4134
func (sel *ColSelector) String() string {
48✔
4135
        return sel.col
48✔
4136
}
48✔
4137

4138
type AggColSelector struct {
4139
        aggFn AggregateFn
4140
        table string
4141
        col   string
4142
}
4143

4144
func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector {
16✔
4145
        return &AggColSelector{
16✔
4146
                aggFn: aggFn,
16✔
4147
                table: table,
16✔
4148
                col:   col,
16✔
4149
        }
16✔
4150
}
16✔
4151

4152
func EncodeSelector(aggFn, table, col string) string {
1,417,490✔
4153
        return aggFn + "(" + table + "." + col + ")"
1,417,490✔
4154
}
1,417,490✔
4155

4156
func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,586✔
4157
        table = implicitTable
1,586✔
4158
        if sel.table != "" {
1,717✔
4159
                table = sel.table
131✔
4160
        }
131✔
4161
        return sel.aggFn, table, sel.col
1,586✔
4162
}
4163

4164
func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
36✔
4165
        if sel.aggFn == COUNT {
55✔
4166
                return IntegerType, nil
19✔
4167
        }
19✔
4168

4169
        colSelector := &ColSelector{table: sel.table, col: sel.col}
17✔
4170

17✔
4171
        if sel.aggFn == SUM || sel.aggFn == AVG {
24✔
4172
                t, err := colSelector.inferType(cols, params, implicitTable)
7✔
4173
                if err != nil {
7✔
4174
                        return AnyType, err
×
4175
                }
×
4176

4177
                if t != IntegerType && t != Float64Type {
7✔
4178
                        return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
×
4179

×
4180
                }
×
4181

4182
                return t, nil
7✔
4183
        }
4184

4185
        return colSelector.inferType(cols, params, implicitTable)
10✔
4186
}
4187

4188
func (sel *AggColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
8✔
4189
        if sel.aggFn == COUNT {
10✔
4190
                if t != IntegerType {
3✔
4191
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4192
                }
1✔
4193
                return nil
1✔
4194
        }
4195

4196
        colSelector := &ColSelector{table: sel.table, col: sel.col}
6✔
4197

6✔
4198
        if sel.aggFn == SUM || sel.aggFn == AVG {
10✔
4199
                if t != IntegerType && t != Float64Type {
5✔
4200
                        return fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
1✔
4201
                }
1✔
4202
        }
4203

4204
        return colSelector.requiresType(t, cols, params, implicitTable)
5✔
4205
}
4206

4207
func (sel *AggColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
412✔
4208
        return sel, nil
412✔
4209
}
412✔
4210

4211
func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
459✔
4212
        if row == nil {
460✔
4213
                return nil, fmt.Errorf("%w: no row to evaluate aggregation (%s) in current context", ErrInvalidValue, sel.aggFn)
1✔
4214
        }
1✔
4215

4216
        v, ok := row.ValuesBySelector[EncodeSelector(sel.resolve(implicitTable))]
458✔
4217
        if !ok {
459✔
4218
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, sel.col)
1✔
4219
        }
1✔
4220
        return v, nil
457✔
4221
}
4222

4223
func (sel *AggColSelector) selectors() []Selector {
232✔
4224
        return []Selector{sel}
232✔
4225
}
232✔
4226

4227
func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4228
        return sel
×
4229
}
×
4230

4231
func (sel *AggColSelector) isConstant() bool {
1✔
4232
        return false
1✔
4233
}
1✔
4234

4235
func (sel *AggColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
4236
        return nil
×
4237
}
×
4238

4239
func (sel *AggColSelector) String() string {
×
4240
        return sel.aggFn + "(" + sel.col + ")"
×
4241
}
×
4242

4243
type NumExp struct {
4244
        op          NumOperator
4245
        left, right ValueExp
4246
}
4247

4248
func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
17✔
4249
        // First step - check if we can infer the type of sub-expressions
17✔
4250
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
17✔
4251
        if err != nil {
17✔
4252
                return AnyType, err
×
4253
        }
×
4254
        if tleft != AnyType && tleft != IntegerType && tleft != Float64Type && tleft != JSONType {
17✔
4255
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft)
×
4256
        }
×
4257

4258
        tright, err := bexp.right.inferType(cols, params, implicitTable)
17✔
4259
        if err != nil {
17✔
4260
                return AnyType, err
×
4261
        }
×
4262
        if tright != AnyType && tright != IntegerType && tright != Float64Type && tright != JSONType {
19✔
4263
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright)
2✔
4264
        }
2✔
4265

4266
        if tleft == IntegerType && tright == IntegerType {
19✔
4267
                // Both sides are integer types - the result is also integer
4✔
4268
                return IntegerType, nil
4✔
4269
        }
4✔
4270

4271
        if tleft != AnyType && tright != AnyType {
20✔
4272
                // Both sides have concrete types but at least one of them is float
9✔
4273
                return Float64Type, nil
9✔
4274
        }
9✔
4275

4276
        // Both sides are ambiguous
4277
        return AnyType, nil
2✔
4278
}
4279

4280
func copyParams(params map[string]SQLValueType) map[string]SQLValueType {
11✔
4281
        ret := make(map[string]SQLValueType, len(params))
11✔
4282
        for k, v := range params {
15✔
4283
                ret[k] = v
4✔
4284
        }
4✔
4285
        return ret
11✔
4286
}
4287

4288
func restoreParams(params, restore map[string]SQLValueType) {
2✔
4289
        for k := range params {
2✔
4290
                delete(params, k)
×
4291
        }
×
4292
        for k, v := range restore {
2✔
4293
                params[k] = v
×
4294
        }
×
4295
}
4296

4297
func (bexp *NumExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4298
        if t != IntegerType && t != Float64Type {
8✔
4299
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4300
        }
1✔
4301

4302
        floatArgs := 2
6✔
4303
        paramsOrig := copyParams(params)
6✔
4304
        err := bexp.left.requiresType(t, cols, params, implicitTable)
6✔
4305
        if err != nil && t == Float64Type {
7✔
4306
                restoreParams(params, paramsOrig)
1✔
4307
                floatArgs--
1✔
4308
                err = bexp.left.requiresType(IntegerType, cols, params, implicitTable)
1✔
4309
        }
1✔
4310
        if err != nil {
7✔
4311
                return err
1✔
4312
        }
1✔
4313

4314
        paramsOrig = copyParams(params)
5✔
4315
        err = bexp.right.requiresType(t, cols, params, implicitTable)
5✔
4316
        if err != nil && t == Float64Type {
6✔
4317
                restoreParams(params, paramsOrig)
1✔
4318
                floatArgs--
1✔
4319
                err = bexp.right.requiresType(IntegerType, cols, params, implicitTable)
1✔
4320
        }
1✔
4321
        if err != nil {
7✔
4322
                return err
2✔
4323
        }
2✔
4324

4325
        if t == Float64Type && floatArgs == 0 {
3✔
4326
                // Currently this case requires explicit float cast
×
4327
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
×
4328
        }
×
4329

4330
        return nil
3✔
4331
}
4332

4333
func (bexp *NumExp) substitute(params map[string]interface{}) (ValueExp, error) {
187✔
4334
        rlexp, err := bexp.left.substitute(params)
187✔
4335
        if err != nil {
187✔
4336
                return nil, err
×
4337
        }
×
4338

4339
        rrexp, err := bexp.right.substitute(params)
187✔
4340
        if err != nil {
187✔
4341
                return nil, err
×
4342
        }
×
4343

4344
        bexp.left = rlexp
187✔
4345
        bexp.right = rrexp
187✔
4346

187✔
4347
        return bexp, nil
187✔
4348
}
4349

4350
func (bexp *NumExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
124,242✔
4351
        vl, err := bexp.left.reduce(tx, row, implicitTable)
124,242✔
4352
        if err != nil {
124,242✔
4353
                return nil, err
×
4354
        }
×
4355

4356
        vr, err := bexp.right.reduce(tx, row, implicitTable)
124,242✔
4357
        if err != nil {
124,242✔
4358
                return nil, err
×
4359
        }
×
4360

4361
        vl = unwrapJSON(vl)
124,242✔
4362
        vr = unwrapJSON(vr)
124,242✔
4363

124,242✔
4364
        return applyNumOperator(bexp.op, vl, vr)
124,242✔
4365
}
4366

4367
func unwrapJSON(v TypedValue) TypedValue {
248,484✔
4368
        if jsonVal, ok := v.(*JSON); ok {
248,584✔
4369
                if sv, isSimple := jsonVal.castToTypedValue(); isSimple {
200✔
4370
                        return sv
100✔
4371
                }
100✔
4372
        }
4373
        return v
248,384✔
4374
}
4375

4376
func (bexp *NumExp) selectors() []Selector {
13✔
4377
        return append(bexp.left.selectors(), bexp.right.selectors()...)
13✔
4378
}
13✔
4379

4380
func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4381
        return &NumExp{
1✔
4382
                op:    bexp.op,
1✔
4383
                left:  bexp.left.reduceSelectors(row, implicitTable),
1✔
4384
                right: bexp.right.reduceSelectors(row, implicitTable),
1✔
4385
        }
1✔
4386
}
1✔
4387

4388
func (bexp *NumExp) isConstant() bool {
5✔
4389
        return bexp.left.isConstant() && bexp.right.isConstant()
5✔
4390
}
5✔
4391

4392
func (bexp *NumExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
4✔
4393
        return nil
4✔
4394
}
4✔
4395

4396
func (bexp *NumExp) String() string {
18✔
4397
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), NumOperatorString(bexp.op), bexp.right.String())
18✔
4398
}
18✔
4399

4400
type NotBoolExp struct {
4401
        exp ValueExp
4402
}
4403

4404
func (bexp *NotBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4405
        err := bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
1✔
4406
        if err != nil {
1✔
4407
                return AnyType, err
×
4408
        }
×
4409

4410
        return BooleanType, nil
1✔
4411
}
4412

4413
func (bexp *NotBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
4414
        if t != BooleanType {
7✔
4415
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4416
        }
1✔
4417

4418
        return bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
5✔
4419
}
4420

4421
func (bexp *NotBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
22✔
4422
        rexp, err := bexp.exp.substitute(params)
22✔
4423
        if err != nil {
22✔
4424
                return nil, err
×
4425
        }
×
4426

4427
        bexp.exp = rexp
22✔
4428

22✔
4429
        return bexp, nil
22✔
4430
}
4431

4432
func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
22✔
4433
        v, err := bexp.exp.reduce(tx, row, implicitTable)
22✔
4434
        if err != nil {
22✔
4435
                return nil, err
×
4436
        }
×
4437

4438
        r, isBool := v.RawValue().(bool)
22✔
4439
        if !isBool {
22✔
4440
                return nil, ErrInvalidCondition
×
4441
        }
×
4442

4443
        return &Bool{val: !r}, nil
22✔
4444
}
4445

4446
func (bexp *NotBoolExp) selectors() []Selector {
×
4447
        return bexp.exp.selectors()
×
4448
}
×
4449

4450
func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4451
        return &NotBoolExp{
×
4452
                exp: bexp.exp.reduceSelectors(row, implicitTable),
×
4453
        }
×
4454
}
×
4455

4456
func (bexp *NotBoolExp) isConstant() bool {
1✔
4457
        return bexp.exp.isConstant()
1✔
4458
}
1✔
4459

4460
func (bexp *NotBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
4461
        return nil
7✔
4462
}
7✔
4463

4464
func (bexp *NotBoolExp) String() string {
12✔
4465
        return fmt.Sprintf("(NOT %s)", bexp.exp.String())
12✔
4466
}
12✔
4467

4468
type LikeBoolExp struct {
4469
        val     ValueExp
4470
        notLike bool
4471
        pattern ValueExp
4472
}
4473

4474
func NewLikeBoolExp(val ValueExp, notLike bool, pattern ValueExp) *LikeBoolExp {
4✔
4475
        return &LikeBoolExp{
4✔
4476
                val:     val,
4✔
4477
                notLike: notLike,
4✔
4478
                pattern: pattern,
4✔
4479
        }
4✔
4480
}
4✔
4481

4482
func (bexp *LikeBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
4483
        if bexp.val == nil || bexp.pattern == nil {
4✔
4484
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4485
        }
1✔
4486

4487
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
2✔
4488
        if err != nil {
3✔
4489
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4490
        }
1✔
4491

4492
        return BooleanType, nil
1✔
4493
}
4494

4495
func (bexp *LikeBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4496
        if bexp.val == nil || bexp.pattern == nil {
9✔
4497
                return fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
2✔
4498
        }
2✔
4499

4500
        if t != BooleanType {
7✔
4501
                return fmt.Errorf("error using the value of the LIKE operator as %s: %w", t, ErrInvalidTypes)
2✔
4502
        }
2✔
4503

4504
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
3✔
4505
        if err != nil {
4✔
4506
                return fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4507
        }
1✔
4508

4509
        return nil
2✔
4510
}
4511

4512
func (bexp *LikeBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
135✔
4513
        if bexp.val == nil || bexp.pattern == nil {
136✔
4514
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4515
        }
1✔
4516

4517
        val, err := bexp.val.substitute(params)
134✔
4518
        if err != nil {
134✔
4519
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4520
        }
×
4521

4522
        pattern, err := bexp.pattern.substitute(params)
134✔
4523
        if err != nil {
134✔
4524
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4525
        }
×
4526

4527
        return &LikeBoolExp{
134✔
4528
                val:     val,
134✔
4529
                notLike: bexp.notLike,
134✔
4530
                pattern: pattern,
134✔
4531
        }, nil
134✔
4532
}
4533

4534
func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
141✔
4535
        if bexp.val == nil || bexp.pattern == nil {
142✔
4536
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4537
        }
1✔
4538

4539
        rval, err := bexp.val.reduce(tx, row, implicitTable)
140✔
4540
        if err != nil {
140✔
4541
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4542
        }
×
4543

4544
        if rval.IsNull() {
141✔
4545
                return &Bool{val: bexp.notLike}, nil
1✔
4546
        }
1✔
4547

4548
        rvalStr, ok := rval.RawValue().(string)
139✔
4549
        if !ok {
140✔
4550
                return nil, fmt.Errorf("error in 'LIKE' clause: %w (expecting %s)", ErrInvalidTypes, VarcharType)
1✔
4551
        }
1✔
4552

4553
        rpattern, err := bexp.pattern.reduce(tx, row, implicitTable)
138✔
4554
        if err != nil {
138✔
4555
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4556
        }
×
4557

4558
        if rpattern.Type() != VarcharType {
138✔
4559
                return nil, fmt.Errorf("error evaluating 'LIKE' clause: %w", ErrInvalidTypes)
×
4560
        }
×
4561

4562
        matched, err := regexp.MatchString(rpattern.RawValue().(string), rvalStr)
138✔
4563
        if err != nil {
138✔
4564
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4565
        }
×
4566

4567
        return &Bool{val: matched != bexp.notLike}, nil
138✔
4568
}
4569

4570
func (bexp *LikeBoolExp) selectors() []Selector {
1✔
4571
        return bexp.val.selectors()
1✔
4572
}
1✔
4573

4574
func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4575
        return bexp
1✔
4576
}
1✔
4577

4578
func (bexp *LikeBoolExp) isConstant() bool {
2✔
4579
        return false
2✔
4580
}
2✔
4581

4582
func (bexp *LikeBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
4583
        return nil
8✔
4584
}
8✔
4585

4586
func (bexp *LikeBoolExp) String() string {
5✔
4587
        fmtStr := "(%s LIKE %s)"
5✔
4588
        if bexp.notLike {
6✔
4589
                fmtStr = "(%s NOT LIKE %s)"
1✔
4590
        }
1✔
4591
        return fmt.Sprintf(fmtStr, bexp.val.String(), bexp.pattern.String())
5✔
4592
}
4593

4594
type CmpBoolExp struct {
4595
        op          CmpOperator
4596
        left, right ValueExp
4597
}
4598

4599
func NewCmpBoolExp(op CmpOperator, left, right ValueExp) *CmpBoolExp {
67✔
4600
        return &CmpBoolExp{
67✔
4601
                op:    op,
67✔
4602
                left:  left,
67✔
4603
                right: right,
67✔
4604
        }
67✔
4605
}
67✔
4606

4607
func (bexp *CmpBoolExp) Left() ValueExp {
×
4608
        return bexp.left
×
4609
}
×
4610

4611
func (bexp *CmpBoolExp) Right() ValueExp {
×
4612
        return bexp.right
×
4613
}
×
4614

4615
func (bexp *CmpBoolExp) OP() CmpOperator {
×
4616
        return bexp.op
×
4617
}
×
4618

4619
func (bexp *CmpBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
63✔
4620
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
63✔
4621
        if err != nil {
63✔
4622
                return AnyType, err
×
4623
        }
×
4624

4625
        tright, err := bexp.right.inferType(cols, params, implicitTable)
63✔
4626
        if err != nil {
65✔
4627
                return AnyType, err
2✔
4628
        }
2✔
4629

4630
        // unification step
4631

4632
        if tleft == tright {
74✔
4633
                return BooleanType, nil
13✔
4634
        }
13✔
4635

4636
        _, ok := coerceTypes(tleft, tright)
48✔
4637
        if !ok {
52✔
4638
                return AnyType, fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, tleft, tright)
4✔
4639
        }
4✔
4640

4641
        if tleft == AnyType {
47✔
4642
                err = bexp.left.requiresType(tright, cols, params, implicitTable)
3✔
4643
                if err != nil {
3✔
4644
                        return AnyType, err
×
4645
                }
×
4646
        }
4647

4648
        if tright == AnyType {
84✔
4649
                err = bexp.right.requiresType(tleft, cols, params, implicitTable)
40✔
4650
                if err != nil {
41✔
4651
                        return AnyType, err
1✔
4652
                }
1✔
4653
        }
4654
        return BooleanType, nil
43✔
4655
}
4656

4657
func coerceTypes(t1, t2 SQLValueType) (SQLValueType, bool) {
48✔
4658
        switch {
48✔
4659
        case t1 == t2:
×
4660
                return t1, true
×
4661
        case t1 == AnyType:
3✔
4662
                return t2, true
3✔
4663
        case t2 == AnyType:
40✔
4664
                return t1, true
40✔
4665
        case (t1 == IntegerType && t2 == Float64Type) ||
4666
                (t1 == Float64Type && t2 == IntegerType):
1✔
4667
                return Float64Type, true
1✔
4668
        }
4669
        return "", false
4✔
4670
}
4671

4672
func (bexp *CmpBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
41✔
4673
        if t != BooleanType {
42✔
4674
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4675
        }
1✔
4676

4677
        _, err := bexp.inferType(cols, params, implicitTable)
40✔
4678
        return err
40✔
4679
}
4680

4681
func (bexp *CmpBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
14,324✔
4682
        rlexp, err := bexp.left.substitute(params)
14,324✔
4683
        if err != nil {
14,324✔
4684
                return nil, err
×
4685
        }
×
4686

4687
        rrexp, err := bexp.right.substitute(params)
14,324✔
4688
        if err != nil {
14,325✔
4689
                return nil, err
1✔
4690
        }
1✔
4691

4692
        bexp.left = rlexp
14,323✔
4693
        bexp.right = rrexp
14,323✔
4694

14,323✔
4695
        return bexp, nil
14,323✔
4696
}
4697

4698
func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
13,920✔
4699
        vl, err := bexp.left.reduce(tx, row, implicitTable)
13,920✔
4700
        if err != nil {
13,923✔
4701
                return nil, err
3✔
4702
        }
3✔
4703

4704
        vr, err := bexp.right.reduce(tx, row, implicitTable)
13,917✔
4705
        if err != nil {
13,919✔
4706
                return nil, err
2✔
4707
        }
2✔
4708

4709
        r, err := vl.Compare(vr)
13,915✔
4710
        if err != nil {
13,919✔
4711
                return nil, err
4✔
4712
        }
4✔
4713

4714
        return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil
13,911✔
4715
}
4716

4717
func (bexp *CmpBoolExp) selectors() []Selector {
12✔
4718
        return append(bexp.left.selectors(), bexp.right.selectors()...)
12✔
4719
}
12✔
4720

4721
func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
282✔
4722
        return &CmpBoolExp{
282✔
4723
                op:    bexp.op,
282✔
4724
                left:  bexp.left.reduceSelectors(row, implicitTable),
282✔
4725
                right: bexp.right.reduceSelectors(row, implicitTable),
282✔
4726
        }
282✔
4727
}
282✔
4728

4729
func (bexp *CmpBoolExp) isConstant() bool {
2✔
4730
        return bexp.left.isConstant() && bexp.right.isConstant()
2✔
4731
}
2✔
4732

4733
func (bexp *CmpBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
607✔
4734
        matchingFunc := func(_, right ValueExp) (*ColSelector, ValueExp, bool) {
1,423✔
4735
                s, isSel := bexp.left.(*ColSelector)
816✔
4736
                if isSel && s.col != revCol && bexp.right.isConstant() {
1,214✔
4737
                        return s, right, true
398✔
4738
                }
398✔
4739
                return nil, nil, false
418✔
4740
        }
4741

4742
        sel, c, ok := matchingFunc(bexp.left, bexp.right)
607✔
4743
        if !ok {
816✔
4744
                sel, c, ok = matchingFunc(bexp.right, bexp.left)
209✔
4745
        }
209✔
4746

4747
        if !ok {
816✔
4748
                return nil
209✔
4749
        }
209✔
4750

4751
        aggFn, t, col := sel.resolve(table.name)
398✔
4752
        if aggFn != "" || t != asTable {
412✔
4753
                return nil
14✔
4754
        }
14✔
4755

4756
        column, err := table.GetColumnByName(col)
384✔
4757
        if err != nil {
385✔
4758
                return err
1✔
4759
        }
1✔
4760

4761
        val, err := c.substitute(params)
383✔
4762
        if errors.Is(err, ErrMissingParameter) {
442✔
4763
                // TODO: not supported when parameters are not provided during query resolution
59✔
4764
                return nil
59✔
4765
        }
59✔
4766
        if err != nil {
324✔
4767
                return err
×
4768
        }
×
4769

4770
        rval, err := val.reduce(nil, nil, table.name)
324✔
4771
        if err != nil {
325✔
4772
                return err
1✔
4773
        }
1✔
4774

4775
        return updateRangeFor(column.id, rval, bexp.op, rangesByColID)
323✔
4776
}
4777

4778
func (bexp *CmpBoolExp) String() string {
20✔
4779
        opStr := CmpOperatorToString(bexp.op)
20✔
4780
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), opStr, bexp.right.String())
20✔
4781
}
20✔
4782

4783
type TimestampFieldType string
4784

4785
const (
4786
        TimestampFieldTypeYear   TimestampFieldType = "YEAR"
4787
        TimestampFieldTypeMonth  TimestampFieldType = "MONTH"
4788
        TimestampFieldTypeDay    TimestampFieldType = "DAY"
4789
        TimestampFieldTypeHour   TimestampFieldType = "HOUR"
4790
        TimestampFieldTypeMinute TimestampFieldType = "MINUTE"
4791
        TimestampFieldTypeSecond TimestampFieldType = "SECOND"
4792
)
4793

4794
type ExtractFromTimestampExp struct {
4795
        Field TimestampFieldType
4796
        Exp   ValueExp
4797
}
4798

NEW
4799
func (te *ExtractFromTimestampExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
×
NEW
4800
        inferredType, err := te.Exp.inferType(cols, params, implicitTable)
×
NEW
4801
        if err != nil {
×
NEW
4802
                return "", err
×
NEW
4803
        }
×
4804

NEW
4805
        if inferredType != TimestampType &&
×
NEW
4806
                inferredType != VarcharType &&
×
NEW
4807
                inferredType != AnyType {
×
NEW
4808
                return "", fmt.Errorf("timestamp expression must be of type %v or %v, but was: %v", TimestampType, VarcharType, inferredType)
×
NEW
4809
        }
×
NEW
4810
        return IntegerType, nil
×
4811
}
4812

NEW
4813
func (te *ExtractFromTimestampExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
NEW
4814
        if t != IntegerType && t != Float64Type {
×
NEW
4815
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
×
NEW
4816
        }
×
NEW
4817
        return te.Exp.requiresType(TimestampType, cols, params, implicitTable)
×
4818
}
4819

4820
func (te *ExtractFromTimestampExp) substitute(params map[string]interface{}) (ValueExp, error) {
18✔
4821
        exp, err := te.Exp.substitute(params)
18✔
4822
        if err != nil {
18✔
NEW
4823
                return nil, err
×
NEW
4824
        }
×
4825
        return &ExtractFromTimestampExp{
18✔
4826
                Field: te.Field,
18✔
4827
                Exp:   exp,
18✔
4828
        }, nil
18✔
4829
}
4830

4831
func (te *ExtractFromTimestampExp) selectors() []Selector {
12✔
4832
        return te.Exp.selectors()
12✔
4833
}
12✔
4834

4835
func (te *ExtractFromTimestampExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
18✔
4836
        v, err := te.Exp.reduce(tx, row, implicitTable)
18✔
4837
        if err != nil {
18✔
NEW
4838
                return nil, err
×
NEW
4839
        }
×
4840

4841
        if v.IsNull() {
18✔
NEW
4842
                return NewNull(IntegerType), nil
×
NEW
4843
        }
×
4844

4845
        if t := v.Type(); t != TimestampType && t != VarcharType {
18✔
NEW
4846
                return nil, fmt.Errorf("%w: expected type %v but found type %v", ErrInvalidTypes, TimestampType, t)
×
NEW
4847
        }
×
4848

4849
        if v.Type() == VarcharType {
22✔
4850
                converterFunc, err := getConverter(VarcharType, TimestampType)
4✔
4851
                if err != nil {
4✔
NEW
4852
                        return nil, err
×
NEW
4853
                }
×
4854
                casted, err := converterFunc(v)
4✔
4855
                if err != nil {
4✔
NEW
4856
                        return nil, err
×
NEW
4857
                }
×
4858
                v = casted
4✔
4859
        }
4860

4861
        t, _ := v.RawValue().(time.Time)
18✔
4862

18✔
4863
        year, month, day := t.Date()
18✔
4864

18✔
4865
        switch te.Field {
18✔
4866
        case TimestampFieldTypeYear:
3✔
4867
                return NewInteger(int64(year)), nil
3✔
4868
        case TimestampFieldTypeMonth:
3✔
4869
                return NewInteger(int64(month)), nil
3✔
4870
        case TimestampFieldTypeDay:
3✔
4871
                return NewInteger(int64(day)), nil
3✔
4872
        case TimestampFieldTypeHour:
3✔
4873
                return NewInteger(int64(t.Hour())), nil
3✔
4874
        case TimestampFieldTypeMinute:
3✔
4875
                return NewInteger(int64(t.Minute())), nil
3✔
4876
        case TimestampFieldTypeSecond:
3✔
4877
                return NewInteger(int64(t.Second())), nil
3✔
4878
        }
NEW
4879
        return nil, fmt.Errorf("unknown timestamp field type: %s", te.Field)
×
4880
}
4881

NEW
4882
func (te *ExtractFromTimestampExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
NEW
4883
        return &ExtractFromTimestampExp{
×
NEW
4884
                Field: te.Field,
×
NEW
4885
                Exp:   te.Exp.reduceSelectors(row, implicitTable),
×
NEW
4886
        }
×
NEW
4887
}
×
4888

4889
func (te *ExtractFromTimestampExp) isConstant() bool {
1✔
4890
        return false
1✔
4891
}
1✔
4892

NEW
4893
func (te *ExtractFromTimestampExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
NEW
4894
        return nil
×
NEW
4895
}
×
4896

4897
func (te *ExtractFromTimestampExp) String() string {
6✔
4898
        return fmt.Sprintf("EXTRACT(%s FROM %s)", te.Field, te.Exp)
6✔
4899
}
6✔
4900

4901
func updateRangeFor(colID uint32, val TypedValue, cmp CmpOperator, rangesByColID map[uint32]*typedValueRange) error {
323✔
4902
        currRange, ranged := rangesByColID[colID]
323✔
4903
        var newRange *typedValueRange
323✔
4904

323✔
4905
        switch cmp {
323✔
4906
        case EQ:
250✔
4907
                {
500✔
4908
                        newRange = &typedValueRange{
250✔
4909
                                lRange: &typedValueSemiRange{
250✔
4910
                                        val:       val,
250✔
4911
                                        inclusive: true,
250✔
4912
                                },
250✔
4913
                                hRange: &typedValueSemiRange{
250✔
4914
                                        val:       val,
250✔
4915
                                        inclusive: true,
250✔
4916
                                },
250✔
4917
                        }
250✔
4918
                }
250✔
4919
        case LT:
13✔
4920
                {
26✔
4921
                        newRange = &typedValueRange{
13✔
4922
                                hRange: &typedValueSemiRange{
13✔
4923
                                        val: val,
13✔
4924
                                },
13✔
4925
                        }
13✔
4926
                }
13✔
4927
        case LE:
12✔
4928
                {
24✔
4929
                        newRange = &typedValueRange{
12✔
4930
                                hRange: &typedValueSemiRange{
12✔
4931
                                        val:       val,
12✔
4932
                                        inclusive: true,
12✔
4933
                                },
12✔
4934
                        }
12✔
4935
                }
12✔
4936
        case GT:
18✔
4937
                {
36✔
4938
                        newRange = &typedValueRange{
18✔
4939
                                lRange: &typedValueSemiRange{
18✔
4940
                                        val: val,
18✔
4941
                                },
18✔
4942
                        }
18✔
4943
                }
18✔
4944
        case GE:
18✔
4945
                {
36✔
4946
                        newRange = &typedValueRange{
18✔
4947
                                lRange: &typedValueSemiRange{
18✔
4948
                                        val:       val,
18✔
4949
                                        inclusive: true,
18✔
4950
                                },
18✔
4951
                        }
18✔
4952
                }
18✔
4953
        case NE:
12✔
4954
                {
24✔
4955
                        return nil
12✔
4956
                }
12✔
4957
        }
4958

4959
        if !ranged {
617✔
4960
                rangesByColID[colID] = newRange
306✔
4961
                return nil
306✔
4962
        }
306✔
4963

4964
        return currRange.refineWith(newRange)
5✔
4965
}
4966

4967
func cmpSatisfiesOp(cmp int, op CmpOperator) bool {
13,911✔
4968
        switch {
13,911✔
4969
        case cmp == 0:
1,167✔
4970
                {
2,334✔
4971
                        return op == EQ || op == LE || op == GE
1,167✔
4972
                }
1,167✔
4973
        case cmp < 0:
6,488✔
4974
                {
12,976✔
4975
                        return op == NE || op == LT || op == LE
6,488✔
4976
                }
6,488✔
4977
        case cmp > 0:
6,256✔
4978
                {
12,512✔
4979
                        return op == NE || op == GT || op == GE
6,256✔
4980
                }
6,256✔
4981
        }
4982
        return false
×
4983
}
4984

4985
type BinBoolExp struct {
4986
        op          LogicOperator
4987
        left, right ValueExp
4988
}
4989

4990
func NewBinBoolExp(op LogicOperator, lrexp, rrexp ValueExp) *BinBoolExp {
18✔
4991
        bexp := &BinBoolExp{
18✔
4992
                op: op,
18✔
4993
        }
18✔
4994

18✔
4995
        bexp.left = lrexp
18✔
4996
        bexp.right = rrexp
18✔
4997

18✔
4998
        return bexp
18✔
4999
}
18✔
5000

5001
func (bexp *BinBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
20✔
5002
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
20✔
5003
        if err != nil {
20✔
5004
                return AnyType, err
×
5005
        }
×
5006

5007
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
20✔
5008
        if err != nil {
22✔
5009
                return AnyType, err
2✔
5010
        }
2✔
5011

5012
        return BooleanType, nil
18✔
5013
}
5014

5015
func (bexp *BinBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
22✔
5016
        if t != BooleanType {
25✔
5017
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
3✔
5018
        }
3✔
5019

5020
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
5021
        if err != nil {
20✔
5022
                return err
1✔
5023
        }
1✔
5024

5025
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
18✔
5026
        if err != nil {
18✔
5027
                return err
×
5028
        }
×
5029

5030
        return nil
18✔
5031
}
5032

5033
func (bexp *BinBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
576✔
5034
        rlexp, err := bexp.left.substitute(params)
576✔
5035
        if err != nil {
576✔
5036
                return nil, err
×
5037
        }
×
5038

5039
        rrexp, err := bexp.right.substitute(params)
576✔
5040
        if err != nil {
576✔
5041
                return nil, err
×
5042
        }
×
5043

5044
        bexp.left = rlexp
576✔
5045
        bexp.right = rrexp
576✔
5046

576✔
5047
        return bexp, nil
576✔
5048
}
5049

5050
func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
539✔
5051
        vl, err := bexp.left.reduce(tx, row, implicitTable)
539✔
5052
        if err != nil {
540✔
5053
                return nil, err
1✔
5054
        }
1✔
5055

5056
        bl, isBool := vl.(*Bool)
538✔
5057
        if !isBool {
538✔
5058
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
5059
        }
×
5060

5061
        // short-circuit evaluation
5062
        if (bl.val && bexp.op == Or) || (!bl.val && bexp.op == And) {
714✔
5063
                return &Bool{val: bl.val}, nil
176✔
5064
        }
176✔
5065

5066
        vr, err := bexp.right.reduce(tx, row, implicitTable)
362✔
5067
        if err != nil {
363✔
5068
                return nil, err
1✔
5069
        }
1✔
5070

5071
        br, isBool := vr.(*Bool)
361✔
5072
        if !isBool {
361✔
5073
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
5074
        }
×
5075

5076
        switch bexp.op {
361✔
5077
        case And:
340✔
5078
                {
680✔
5079
                        return &Bool{val: bl.val && br.val}, nil
340✔
5080
                }
340✔
5081
        case Or:
21✔
5082
                {
42✔
5083
                        return &Bool{val: bl.val || br.val}, nil
21✔
5084
                }
21✔
5085
        }
5086

5087
        return nil, ErrUnexpected
×
5088
}
5089

5090
func (bexp *BinBoolExp) selectors() []Selector {
2✔
5091
        return append(bexp.left.selectors(), bexp.right.selectors()...)
2✔
5092
}
2✔
5093

5094
func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
15✔
5095
        return &BinBoolExp{
15✔
5096
                op:    bexp.op,
15✔
5097
                left:  bexp.left.reduceSelectors(row, implicitTable),
15✔
5098
                right: bexp.right.reduceSelectors(row, implicitTable),
15✔
5099
        }
15✔
5100
}
15✔
5101

5102
func (bexp *BinBoolExp) isConstant() bool {
1✔
5103
        return bexp.left.isConstant() && bexp.right.isConstant()
1✔
5104
}
1✔
5105

5106
func (bexp *BinBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
153✔
5107
        if bexp.op == And {
292✔
5108
                err := bexp.left.selectorRanges(table, asTable, params, rangesByColID)
139✔
5109
                if err != nil {
139✔
5110
                        return err
×
5111
                }
×
5112

5113
                return bexp.right.selectorRanges(table, asTable, params, rangesByColID)
139✔
5114
        }
5115

5116
        lRanges := make(map[uint32]*typedValueRange)
14✔
5117
        rRanges := make(map[uint32]*typedValueRange)
14✔
5118

14✔
5119
        err := bexp.left.selectorRanges(table, asTable, params, lRanges)
14✔
5120
        if err != nil {
14✔
5121
                return err
×
5122
        }
×
5123

5124
        err = bexp.right.selectorRanges(table, asTable, params, rRanges)
14✔
5125
        if err != nil {
14✔
5126
                return err
×
5127
        }
×
5128

5129
        for colID, lr := range lRanges {
21✔
5130
                rr, ok := rRanges[colID]
7✔
5131
                if !ok {
9✔
5132
                        continue
2✔
5133
                }
5134

5135
                err = lr.extendWith(rr)
5✔
5136
                if err != nil {
5✔
5137
                        return err
×
5138
                }
×
5139

5140
                rangesByColID[colID] = lr
5✔
5141
        }
5142

5143
        return nil
14✔
5144
}
5145

5146
func (bexp *BinBoolExp) String() string {
31✔
5147
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), LogicOperatorToString(bexp.op), bexp.right.String())
31✔
5148
}
31✔
5149

5150
type ExistsBoolExp struct {
5151
        q DataSource
5152
}
5153

5154
func (bexp *ExistsBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5155
        return AnyType, fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5156
}
1✔
5157

5158
func (bexp *ExistsBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5159
        return fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5160
}
1✔
5161

5162
func (bexp *ExistsBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5163
        return bexp, nil
1✔
5164
}
1✔
5165

5166
func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5167
        return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported)
2✔
5168
}
2✔
5169

5170
func (bexp *ExistsBoolExp) selectors() []Selector {
1✔
5171
        return nil
1✔
5172
}
1✔
5173

5174
func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5175
        return bexp
1✔
5176
}
1✔
5177

5178
func (bexp *ExistsBoolExp) isConstant() bool {
2✔
5179
        return false
2✔
5180
}
2✔
5181

5182
func (bexp *ExistsBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5183
        return nil
1✔
5184
}
1✔
5185

5186
func (bexp *ExistsBoolExp) String() string {
×
5187
        return ""
×
5188
}
×
5189

5190
type InSubQueryExp struct {
5191
        val   ValueExp
5192
        notIn bool
5193
        q     *SelectStmt
5194
}
5195

5196
func (bexp *InSubQueryExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5197
        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5198
}
1✔
5199

5200
func (bexp *InSubQueryExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5201
        return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5202
}
1✔
5203

5204
func (bexp *InSubQueryExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5205
        return bexp, nil
1✔
5206
}
1✔
5207

5208
func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5209
        return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
2✔
5210
}
2✔
5211

5212
func (bexp *InSubQueryExp) selectors() []Selector {
1✔
5213
        return bexp.val.selectors()
1✔
5214
}
1✔
5215

5216
func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5217
        return bexp
1✔
5218
}
1✔
5219

5220
func (bexp *InSubQueryExp) isConstant() bool {
1✔
5221
        return false
1✔
5222
}
1✔
5223

5224
func (bexp *InSubQueryExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5225
        return nil
1✔
5226
}
1✔
5227

5228
func (bexp *InSubQueryExp) String() string {
×
5229
        return ""
×
5230
}
×
5231

5232
// TODO: once InSubQueryExp is supported, this struct may become obsolete by creating a ListDataSource struct
5233
type InListExp struct {
5234
        val    ValueExp
5235
        notIn  bool
5236
        values []ValueExp
5237
}
5238

5239
func (bexp *InListExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
6✔
5240
        t, err := bexp.val.inferType(cols, params, implicitTable)
6✔
5241
        if err != nil {
8✔
5242
                return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
2✔
5243
        }
2✔
5244

5245
        for _, v := range bexp.values {
12✔
5246
                err = v.requiresType(t, cols, params, implicitTable)
8✔
5247
                if err != nil {
9✔
5248
                        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
5249
                }
1✔
5250
        }
5251

5252
        return BooleanType, nil
3✔
5253
}
5254

5255
func (bexp *InListExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
5256
        _, err := bexp.inferType(cols, params, implicitTable)
2✔
5257
        if err != nil {
3✔
5258
                return err
1✔
5259
        }
1✔
5260

5261
        if t != BooleanType {
1✔
5262
                return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrInvalidTypes)
×
5263
        }
×
5264

5265
        return nil
1✔
5266
}
5267

5268
func (bexp *InListExp) substitute(params map[string]interface{}) (ValueExp, error) {
115✔
5269
        val, err := bexp.val.substitute(params)
115✔
5270
        if err != nil {
115✔
5271
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5272
        }
×
5273

5274
        values := make([]ValueExp, len(bexp.values))
115✔
5275

115✔
5276
        for i, val := range bexp.values {
245✔
5277
                values[i], err = val.substitute(params)
130✔
5278
                if err != nil {
130✔
5279
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5280
                }
×
5281
        }
5282

5283
        return &InListExp{
115✔
5284
                val:    val,
115✔
5285
                notIn:  bexp.notIn,
115✔
5286
                values: values,
115✔
5287
        }, nil
115✔
5288
}
5289

5290
func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
115✔
5291
        rval, err := bexp.val.reduce(tx, row, implicitTable)
115✔
5292
        if err != nil {
116✔
5293
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5294
        }
1✔
5295

5296
        var found bool
114✔
5297

114✔
5298
        for _, v := range bexp.values {
241✔
5299
                rv, err := v.reduce(tx, row, implicitTable)
127✔
5300
                if err != nil {
128✔
5301
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5302
                }
1✔
5303

5304
                r, err := rval.Compare(rv)
126✔
5305
                if err != nil {
127✔
5306
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5307
                }
1✔
5308

5309
                if r == 0 {
140✔
5310
                        // TODO: short-circuit evaluation may be preferred when upfront static type inference is in place
15✔
5311
                        found = found || true
15✔
5312
                }
15✔
5313
        }
5314

5315
        return &Bool{val: found != bexp.notIn}, nil
112✔
5316
}
5317

5318
func (bexp *InListExp) selectors() []Selector {
1✔
5319
        selectors := make([]Selector, 0, len(bexp.values))
1✔
5320
        for _, v := range bexp.values {
4✔
5321
                selectors = append(selectors, v.selectors()...)
3✔
5322
        }
3✔
5323
        return append(bexp.val.selectors(), selectors...)
1✔
5324
}
5325

5326
func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
5327
        values := make([]ValueExp, len(bexp.values))
10✔
5328

10✔
5329
        for i, val := range bexp.values {
20✔
5330
                values[i] = val.reduceSelectors(row, implicitTable)
10✔
5331
        }
10✔
5332

5333
        return &InListExp{
10✔
5334
                val:    bexp.val.reduceSelectors(row, implicitTable),
10✔
5335
                values: values,
10✔
5336
        }
10✔
5337
}
5338

5339
func (bexp *InListExp) isConstant() bool {
1✔
5340
        return false
1✔
5341
}
1✔
5342

5343
func (bexp *InListExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
21✔
5344
        // TODO: may be determiined by smallest and bigggest value in the list
21✔
5345
        return nil
21✔
5346
}
21✔
5347

5348
func (bexp *InListExp) String() string {
1✔
5349
        values := make([]string, len(bexp.values))
1✔
5350
        for i, exp := range bexp.values {
5✔
5351
                values[i] = exp.String()
4✔
5352
        }
4✔
5353
        return fmt.Sprintf("%s IN (%s)", bexp.val.String(), strings.Join(values, ","))
1✔
5354
}
5355

5356
type FnDataSourceStmt struct {
5357
        fnCall *FnCall
5358
        as     string
5359
}
5360

5361
func (stmt *FnDataSourceStmt) readOnly() bool {
1✔
5362
        return true
1✔
5363
}
1✔
5364

5365
func (stmt *FnDataSourceStmt) requiredPrivileges() []SQLPrivilege {
1✔
5366
        return nil
1✔
5367
}
1✔
5368

5369
func (stmt *FnDataSourceStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
5370
        return tx, nil
×
5371
}
×
5372

5373
func (stmt *FnDataSourceStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5374
        return nil
1✔
5375
}
1✔
5376

5377
func (stmt *FnDataSourceStmt) Alias() string {
24✔
5378
        if stmt.as != "" {
26✔
5379
                return stmt.as
2✔
5380
        }
2✔
5381

5382
        switch strings.ToUpper(stmt.fnCall.fn) {
22✔
5383
        case DatabasesFnCall:
3✔
5384
                {
6✔
5385
                        return "databases"
3✔
5386
                }
3✔
5387
        case TablesFnCall:
5✔
5388
                {
10✔
5389
                        return "tables"
5✔
5390
                }
5✔
5391
        case TableFnCall:
×
5392
                {
×
5393
                        return "table"
×
5394
                }
×
5395
        case UsersFnCall:
7✔
5396
                {
14✔
5397
                        return "users"
7✔
5398
                }
7✔
5399
        case ColumnsFnCall:
3✔
5400
                {
6✔
5401
                        return "columns"
3✔
5402
                }
3✔
5403
        case IndexesFnCall:
2✔
5404
                {
4✔
5405
                        return "indexes"
2✔
5406
                }
2✔
5407
        case GrantsFnCall:
2✔
5408
                return "grants"
2✔
5409
        }
5410

5411
        // not reachable
5412
        return ""
×
5413
}
5414

5415
func (stmt *FnDataSourceStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (rowReader RowReader, err error) {
25✔
5416
        if stmt.fnCall == nil {
25✔
5417
                return nil, fmt.Errorf("%w: function is unspecified", ErrIllegalArguments)
×
5418
        }
×
5419

5420
        switch strings.ToUpper(stmt.fnCall.fn) {
25✔
5421
        case DatabasesFnCall:
5✔
5422
                {
10✔
5423
                        return stmt.resolveListDatabases(ctx, tx, params, scanSpecs)
5✔
5424
                }
5✔
5425
        case TablesFnCall:
5✔
5426
                {
10✔
5427
                        return stmt.resolveListTables(ctx, tx, params, scanSpecs)
5✔
5428
                }
5✔
5429
        case TableFnCall:
×
5430
                {
×
5431
                        return stmt.resolveShowTable(ctx, tx, params, scanSpecs)
×
5432
                }
×
5433
        case UsersFnCall:
7✔
5434
                {
14✔
5435
                        return stmt.resolveListUsers(ctx, tx, params, scanSpecs)
7✔
5436
                }
7✔
5437
        case ColumnsFnCall:
3✔
5438
                {
6✔
5439
                        return stmt.resolveListColumns(ctx, tx, params, scanSpecs)
3✔
5440
                }
3✔
5441
        case IndexesFnCall:
3✔
5442
                {
6✔
5443
                        return stmt.resolveListIndexes(ctx, tx, params, scanSpecs)
3✔
5444
                }
3✔
5445
        case GrantsFnCall:
2✔
5446
                {
4✔
5447
                        return stmt.resolveListGrants(ctx, tx, params, scanSpecs)
2✔
5448
                }
2✔
5449
        }
5450

5451
        return nil, fmt.Errorf("%w (%s)", ErrFunctionDoesNotExist, stmt.fnCall.fn)
×
5452
}
5453

5454
func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5455
        if len(stmt.fnCall.params) > 0 {
5✔
5456
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, DatabasesFnCall, len(stmt.fnCall.params))
×
5457
        }
×
5458

5459
        cols := make([]ColDescriptor, 1)
5✔
5460
        cols[0] = ColDescriptor{
5✔
5461
                Column: "name",
5✔
5462
                Type:   VarcharType,
5✔
5463
        }
5✔
5464

5✔
5465
        var dbs []string
5✔
5466

5✔
5467
        if tx.engine.multidbHandler == nil {
6✔
5468
                return nil, ErrUnspecifiedMultiDBHandler
1✔
5469
        } else {
5✔
5470
                dbs, err = tx.engine.multidbHandler.ListDatabases(ctx)
4✔
5471
                if err != nil {
4✔
5472
                        return nil, err
×
5473
                }
×
5474
        }
5475

5476
        values := make([][]ValueExp, len(dbs))
4✔
5477

4✔
5478
        for i, db := range dbs {
12✔
5479
                values[i] = []ValueExp{&Varchar{val: db}}
8✔
5480
        }
8✔
5481

5482
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
4✔
5483
}
5484

5485
func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5486
        if len(stmt.fnCall.params) > 0 {
5✔
5487
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, TablesFnCall, len(stmt.fnCall.params))
×
5488
        }
×
5489

5490
        cols := make([]ColDescriptor, 1)
5✔
5491
        cols[0] = ColDescriptor{
5✔
5492
                Column: "name",
5✔
5493
                Type:   VarcharType,
5✔
5494
        }
5✔
5495

5✔
5496
        tables := tx.catalog.GetTables()
5✔
5497

5✔
5498
        values := make([][]ValueExp, len(tables))
5✔
5499

5✔
5500
        for i, t := range tables {
14✔
5501
                values[i] = []ValueExp{&Varchar{val: t.name}}
9✔
5502
        }
9✔
5503

5504
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5505
}
5506

5507
func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
×
5508
        cols := []ColDescriptor{
×
5509
                {
×
5510
                        Column: "column_name",
×
5511
                        Type:   VarcharType,
×
5512
                },
×
5513
                {
×
5514
                        Column: "type_name",
×
5515
                        Type:   VarcharType,
×
5516
                },
×
5517
                {
×
5518
                        Column: "is_nullable",
×
5519
                        Type:   BooleanType,
×
5520
                },
×
5521
                {
×
5522
                        Column: "is_indexed",
×
5523
                        Type:   VarcharType,
×
5524
                },
×
5525
                {
×
5526
                        Column: "is_auto_increment",
×
5527
                        Type:   BooleanType,
×
5528
                },
×
5529
                {
×
5530
                        Column: "is_unique",
×
5531
                        Type:   BooleanType,
×
5532
                },
×
5533
        }
×
5534

×
5535
        tableName, _ := stmt.fnCall.params[0].reduce(tx, nil, "")
×
5536
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
×
5537
        if err != nil {
×
5538
                return nil, err
×
5539
        }
×
5540

5541
        values := make([][]ValueExp, len(table.cols))
×
5542

×
5543
        for i, c := range table.cols {
×
5544
                index := "NO"
×
5545

×
5546
                indexed, err := table.IsIndexed(c.Name())
×
5547
                if err != nil {
×
5548
                        return nil, err
×
5549
                }
×
5550
                if indexed {
×
5551
                        index = "YES"
×
5552
                }
×
5553

5554
                if table.PrimaryIndex().IncludesCol(c.ID()) {
×
5555
                        index = "PRIMARY KEY"
×
5556
                }
×
5557

5558
                var unique bool
×
5559
                for _, index := range table.GetIndexesByColID(c.ID()) {
×
5560
                        if index.IsUnique() && len(index.Cols()) == 1 {
×
5561
                                unique = true
×
5562
                                break
×
5563
                        }
5564
                }
5565

5566
                var maxLen string
×
5567

×
5568
                if c.MaxLen() > 0 && (c.Type() == VarcharType || c.Type() == BLOBType) {
×
5569
                        maxLen = fmt.Sprintf("(%d)", c.MaxLen())
×
5570
                }
×
5571

5572
                values[i] = []ValueExp{
×
5573
                        &Varchar{val: c.colName},
×
5574
                        &Varchar{val: c.Type() + maxLen},
×
5575
                        &Bool{val: c.IsNullable()},
×
5576
                        &Varchar{val: index},
×
5577
                        &Bool{val: c.IsAutoIncremental()},
×
5578
                        &Bool{val: unique},
×
5579
                }
×
5580
        }
5581

5582
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
×
5583
}
5584

5585
func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
7✔
5586
        if len(stmt.fnCall.params) > 0 {
7✔
5587
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, UsersFnCall, len(stmt.fnCall.params))
×
5588
        }
×
5589

5590
        cols := []ColDescriptor{
7✔
5591
                {
7✔
5592
                        Column: "name",
7✔
5593
                        Type:   VarcharType,
7✔
5594
                },
7✔
5595
                {
7✔
5596
                        Column: "permission",
7✔
5597
                        Type:   VarcharType,
7✔
5598
                },
7✔
5599
        }
7✔
5600

7✔
5601
        users, err := tx.ListUsers(ctx)
7✔
5602
        if err != nil {
7✔
5603
                return nil, err
×
5604
        }
×
5605

5606
        values := make([][]ValueExp, len(users))
7✔
5607
        for i, user := range users {
23✔
5608
                perm := user.Permission()
16✔
5609

16✔
5610
                values[i] = []ValueExp{
16✔
5611
                        &Varchar{val: user.Username()},
16✔
5612
                        &Varchar{val: perm},
16✔
5613
                }
16✔
5614
        }
16✔
5615
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
7✔
5616
}
5617

5618
func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5619
        if len(stmt.fnCall.params) != 1 {
3✔
5620
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, ColumnsFnCall)
×
5621
        }
×
5622

5623
        cols := []ColDescriptor{
3✔
5624
                {
3✔
5625
                        Column: "table",
3✔
5626
                        Type:   VarcharType,
3✔
5627
                },
3✔
5628
                {
3✔
5629
                        Column: "name",
3✔
5630
                        Type:   VarcharType,
3✔
5631
                },
3✔
5632
                {
3✔
5633
                        Column: "type",
3✔
5634
                        Type:   VarcharType,
3✔
5635
                },
3✔
5636
                {
3✔
5637
                        Column: "max_length",
3✔
5638
                        Type:   IntegerType,
3✔
5639
                },
3✔
5640
                {
3✔
5641
                        Column: "nullable",
3✔
5642
                        Type:   BooleanType,
3✔
5643
                },
3✔
5644
                {
3✔
5645
                        Column: "auto_increment",
3✔
5646
                        Type:   BooleanType,
3✔
5647
                },
3✔
5648
                {
3✔
5649
                        Column: "indexed",
3✔
5650
                        Type:   BooleanType,
3✔
5651
                },
3✔
5652
                {
3✔
5653
                        Column: "primary",
3✔
5654
                        Type:   BooleanType,
3✔
5655
                },
3✔
5656
                {
3✔
5657
                        Column: "unique",
3✔
5658
                        Type:   BooleanType,
3✔
5659
                },
3✔
5660
        }
3✔
5661

3✔
5662
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5663
        if err != nil {
3✔
5664
                return nil, err
×
5665
        }
×
5666

5667
        tableName, err := val.reduce(tx, nil, "")
3✔
5668
        if err != nil {
3✔
5669
                return nil, err
×
5670
        }
×
5671

5672
        if tableName.Type() != VarcharType {
3✔
5673
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5674
        }
×
5675

5676
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5677
        if err != nil {
3✔
5678
                return nil, err
×
5679
        }
×
5680

5681
        values := make([][]ValueExp, len(table.cols))
3✔
5682

3✔
5683
        for i, c := range table.cols {
11✔
5684
                indexed, err := table.IsIndexed(c.Name())
8✔
5685
                if err != nil {
8✔
5686
                        return nil, err
×
5687
                }
×
5688

5689
                var unique bool
8✔
5690
                for _, index := range table.indexesByColID[c.id] {
16✔
5691
                        if index.IsUnique() && len(index.Cols()) == 1 {
11✔
5692
                                unique = true
3✔
5693
                                break
3✔
5694
                        }
5695
                }
5696

5697
                values[i] = []ValueExp{
8✔
5698
                        &Varchar{val: table.name},
8✔
5699
                        &Varchar{val: c.colName},
8✔
5700
                        &Varchar{val: c.colType},
8✔
5701
                        &Integer{val: int64(c.MaxLen())},
8✔
5702
                        &Bool{val: c.IsNullable()},
8✔
5703
                        &Bool{val: c.autoIncrement},
8✔
5704
                        &Bool{val: indexed},
8✔
5705
                        &Bool{val: table.PrimaryIndex().IncludesCol(c.ID())},
8✔
5706
                        &Bool{val: unique},
8✔
5707
                }
8✔
5708
        }
5709

5710
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5711
}
5712

5713
func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5714
        if len(stmt.fnCall.params) != 1 {
3✔
5715
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, IndexesFnCall)
×
5716
        }
×
5717

5718
        cols := []ColDescriptor{
3✔
5719
                {
3✔
5720
                        Column: "table",
3✔
5721
                        Type:   VarcharType,
3✔
5722
                },
3✔
5723
                {
3✔
5724
                        Column: "name",
3✔
5725
                        Type:   VarcharType,
3✔
5726
                },
3✔
5727
                {
3✔
5728
                        Column: "unique",
3✔
5729
                        Type:   BooleanType,
3✔
5730
                },
3✔
5731
                {
3✔
5732
                        Column: "primary",
3✔
5733
                        Type:   BooleanType,
3✔
5734
                },
3✔
5735
        }
3✔
5736

3✔
5737
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5738
        if err != nil {
3✔
5739
                return nil, err
×
5740
        }
×
5741

5742
        tableName, err := val.reduce(tx, nil, "")
3✔
5743
        if err != nil {
3✔
5744
                return nil, err
×
5745
        }
×
5746

5747
        if tableName.Type() != VarcharType {
3✔
5748
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5749
        }
×
5750

5751
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5752
        if err != nil {
3✔
5753
                return nil, err
×
5754
        }
×
5755

5756
        values := make([][]ValueExp, len(table.indexes))
3✔
5757

3✔
5758
        for i, index := range table.indexes {
10✔
5759
                values[i] = []ValueExp{
7✔
5760
                        &Varchar{val: table.name},
7✔
5761
                        &Varchar{val: index.Name()},
7✔
5762
                        &Bool{val: index.unique},
7✔
5763
                        &Bool{val: index.IsPrimary()},
7✔
5764
                }
7✔
5765
        }
7✔
5766

5767
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5768
}
5769

5770
func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
2✔
5771
        if len(stmt.fnCall.params) > 1 {
2✔
5772
                return nil, fmt.Errorf("%w: function '%s' expect at most one parameter of type %s", ErrIllegalArguments, GrantsFnCall, VarcharType)
×
5773
        }
×
5774

5775
        var username string
2✔
5776
        if len(stmt.fnCall.params) == 1 {
3✔
5777
                val, err := stmt.fnCall.params[0].substitute(params)
1✔
5778
                if err != nil {
1✔
5779
                        return nil, err
×
5780
                }
×
5781

5782
                userVal, err := val.reduce(tx, nil, "")
1✔
5783
                if err != nil {
1✔
5784
                        return nil, err
×
5785
                }
×
5786

5787
                if userVal.Type() != VarcharType {
1✔
5788
                        return nil, fmt.Errorf("%w: expected '%s' for username but type '%s' given instead", ErrIllegalArguments, VarcharType, userVal.Type())
×
5789
                }
×
5790
                username, _ = userVal.RawValue().(string)
1✔
5791
        }
5792

5793
        cols := []ColDescriptor{
2✔
5794
                {
2✔
5795
                        Column: "user",
2✔
5796
                        Type:   VarcharType,
2✔
5797
                },
2✔
5798
                {
2✔
5799
                        Column: "privilege",
2✔
5800
                        Type:   VarcharType,
2✔
5801
                },
2✔
5802
        }
2✔
5803

2✔
5804
        var err error
2✔
5805
        var users []User
2✔
5806

2✔
5807
        if tx.engine.multidbHandler == nil {
2✔
5808
                return nil, ErrUnspecifiedMultiDBHandler
×
5809
        } else {
2✔
5810
                users, err = tx.engine.multidbHandler.ListUsers(ctx)
2✔
5811
                if err != nil {
2✔
5812
                        return nil, err
×
5813
                }
×
5814
        }
5815

5816
        values := make([][]ValueExp, 0, len(users))
2✔
5817

2✔
5818
        for _, user := range users {
4✔
5819
                if username == "" || user.Username() == username {
4✔
5820
                        for _, p := range user.SQLPrivileges() {
6✔
5821
                                values = append(values, []ValueExp{
4✔
5822
                                        &Varchar{val: user.Username()},
4✔
5823
                                        &Varchar{val: string(p)},
4✔
5824
                                })
4✔
5825
                        }
4✔
5826
                }
5827
        }
5828

5829
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
2✔
5830
}
5831

5832
// DropTableStmt represents a statement to delete a table.
5833
type DropTableStmt struct {
5834
        table string
5835
}
5836

5837
func NewDropTableStmt(table string) *DropTableStmt {
6✔
5838
        return &DropTableStmt{table: table}
6✔
5839
}
6✔
5840

5841
func (stmt *DropTableStmt) readOnly() bool {
1✔
5842
        return false
1✔
5843
}
1✔
5844

5845
func (stmt *DropTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
5846
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5847
}
1✔
5848

5849
func (stmt *DropTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5850
        return nil
1✔
5851
}
1✔
5852

5853
/*
5854
Exec executes the delete table statement.
5855
It the table exists, if not it does nothing.
5856
If the table exists, it deletes all the indexes and the table itself.
5857
Note that this is a soft delete of the index and table key,
5858
the data is not deleted, but the metadata is updated.
5859
*/
5860
func (stmt *DropTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
7✔
5861
        if !tx.catalog.ExistTable(stmt.table) {
8✔
5862
                return nil, ErrTableDoesNotExist
1✔
5863
        }
1✔
5864

5865
        table, err := tx.catalog.GetTableByName(stmt.table)
6✔
5866
        if err != nil {
6✔
5867
                return nil, err
×
5868
        }
×
5869

5870
        // delete table
5871
        mappedKey := MapKey(
6✔
5872
                tx.sqlPrefix(),
6✔
5873
                catalogTablePrefix,
6✔
5874
                EncodeID(DatabaseID),
6✔
5875
                EncodeID(table.id),
6✔
5876
        )
6✔
5877
        err = tx.delete(ctx, mappedKey)
6✔
5878
        if err != nil {
6✔
5879
                return nil, err
×
5880
        }
×
5881

5882
        // delete columns
5883
        cols := table.ColumnsByID()
6✔
5884
        for _, col := range cols {
26✔
5885
                mappedKey := MapKey(
20✔
5886
                        tx.sqlPrefix(),
20✔
5887
                        catalogColumnPrefix,
20✔
5888
                        EncodeID(DatabaseID),
20✔
5889
                        EncodeID(col.table.id),
20✔
5890
                        EncodeID(col.id),
20✔
5891
                        []byte(col.colType),
20✔
5892
                )
20✔
5893
                err = tx.delete(ctx, mappedKey)
20✔
5894
                if err != nil {
20✔
5895
                        return nil, err
×
5896
                }
×
5897
        }
5898

5899
        // delete checks
5900
        for name := range table.checkConstraints {
6✔
5901
                key := MapKey(
×
5902
                        tx.sqlPrefix(),
×
5903
                        catalogCheckPrefix,
×
5904
                        EncodeID(DatabaseID),
×
5905
                        EncodeID(table.id),
×
5906
                        []byte(name),
×
5907
                )
×
5908

×
5909
                if err := tx.delete(ctx, key); err != nil {
×
5910
                        return nil, err
×
5911
                }
×
5912
        }
5913

5914
        // delete indexes
5915
        for _, index := range table.indexes {
13✔
5916
                mappedKey := MapKey(
7✔
5917
                        tx.sqlPrefix(),
7✔
5918
                        catalogIndexPrefix,
7✔
5919
                        EncodeID(DatabaseID),
7✔
5920
                        EncodeID(table.id),
7✔
5921
                        EncodeID(index.id),
7✔
5922
                )
7✔
5923
                err = tx.delete(ctx, mappedKey)
7✔
5924
                if err != nil {
7✔
5925
                        return nil, err
×
5926
                }
×
5927

5928
                indexKey := MapKey(
7✔
5929
                        tx.sqlPrefix(),
7✔
5930
                        MappedPrefix,
7✔
5931
                        EncodeID(table.id),
7✔
5932
                        EncodeID(index.id),
7✔
5933
                )
7✔
5934
                err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
14✔
5935
                        return sqlTx.engine.store.DeleteIndex(indexKey)
7✔
5936
                })
7✔
5937
                if err != nil {
7✔
5938
                        return nil, err
×
5939
                }
×
5940
        }
5941

5942
        err = tx.catalog.deleteTable(table)
6✔
5943
        if err != nil {
6✔
5944
                return nil, err
×
5945
        }
×
5946

5947
        tx.mutatedCatalog = true
6✔
5948

6✔
5949
        return tx, nil
6✔
5950
}
5951

5952
// DropIndexStmt represents a statement to delete a table.
5953
type DropIndexStmt struct {
5954
        table string
5955
        cols  []string
5956
}
5957

5958
func NewDropIndexStmt(table string, cols []string) *DropIndexStmt {
4✔
5959
        return &DropIndexStmt{table: table, cols: cols}
4✔
5960
}
4✔
5961

5962
func (stmt *DropIndexStmt) readOnly() bool {
1✔
5963
        return false
1✔
5964
}
1✔
5965

5966
func (stmt *DropIndexStmt) requiredPrivileges() []SQLPrivilege {
1✔
5967
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5968
}
1✔
5969

5970
func (stmt *DropIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5971
        return nil
1✔
5972
}
1✔
5973

5974
/*
5975
Exec executes the delete index statement.
5976
If the index exists, it deletes it. Note that this is a soft delete of the index
5977
the data is not deleted, but the metadata is updated.
5978
*/
5979
func (stmt *DropIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
5980
        if !tx.catalog.ExistTable(stmt.table) {
7✔
5981
                return nil, ErrTableDoesNotExist
1✔
5982
        }
1✔
5983

5984
        table, err := tx.catalog.GetTableByName(stmt.table)
5✔
5985
        if err != nil {
5✔
5986
                return nil, err
×
5987
        }
×
5988

5989
        cols := make([]*Column, len(stmt.cols))
5✔
5990

5✔
5991
        for i, colName := range stmt.cols {
10✔
5992
                col, err := table.GetColumnByName(colName)
5✔
5993
                if err != nil {
5✔
5994
                        return nil, err
×
5995
                }
×
5996

5997
                cols[i] = col
5✔
5998
        }
5999

6000
        index, err := table.GetIndexByName(indexName(table.name, cols))
5✔
6001
        if err != nil {
5✔
6002
                return nil, err
×
6003
        }
×
6004

6005
        // delete index
6006
        mappedKey := MapKey(
5✔
6007
                tx.sqlPrefix(),
5✔
6008
                catalogIndexPrefix,
5✔
6009
                EncodeID(DatabaseID),
5✔
6010
                EncodeID(table.id),
5✔
6011
                EncodeID(index.id),
5✔
6012
        )
5✔
6013
        err = tx.delete(ctx, mappedKey)
5✔
6014
        if err != nil {
5✔
6015
                return nil, err
×
6016
        }
×
6017

6018
        indexKey := MapKey(
5✔
6019
                tx.sqlPrefix(),
5✔
6020
                MappedPrefix,
5✔
6021
                EncodeID(table.id),
5✔
6022
                EncodeID(index.id),
5✔
6023
        )
5✔
6024

5✔
6025
        err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
9✔
6026
                return sqlTx.engine.store.DeleteIndex(indexKey)
4✔
6027
        })
4✔
6028
        if err != nil {
5✔
6029
                return nil, err
×
6030
        }
×
6031

6032
        err = table.deleteIndex(index)
5✔
6033
        if err != nil {
6✔
6034
                return nil, err
1✔
6035
        }
1✔
6036

6037
        tx.mutatedCatalog = true
4✔
6038

4✔
6039
        return tx, nil
4✔
6040
}
6041

6042
type SQLPrivilege string
6043

6044
const (
6045
        SQLPrivilegeSelect SQLPrivilege = "SELECT"
6046
        SQLPrivilegeCreate SQLPrivilege = "CREATE"
6047
        SQLPrivilegeInsert SQLPrivilege = "INSERT"
6048
        SQLPrivilegeUpdate SQLPrivilege = "UPDATE"
6049
        SQLPrivilegeDelete SQLPrivilege = "DELETE"
6050
        SQLPrivilegeDrop   SQLPrivilege = "DROP"
6051
        SQLPrivilegeAlter  SQLPrivilege = "ALTER"
6052
)
6053

6054
var allPrivileges = []SQLPrivilege{
6055
        SQLPrivilegeSelect,
6056
        SQLPrivilegeCreate,
6057
        SQLPrivilegeInsert,
6058
        SQLPrivilegeUpdate,
6059
        SQLPrivilegeDelete,
6060
        SQLPrivilegeDrop,
6061
        SQLPrivilegeAlter,
6062
}
6063

6064
func DefaultSQLPrivilegesForPermission(p Permission) []SQLPrivilege {
295✔
6065
        switch p {
295✔
6066
        case PermissionSysAdmin, PermissionAdmin, PermissionReadWrite:
284✔
6067
                return allPrivileges
284✔
6068
        case PermissionReadOnly:
11✔
6069
                return []SQLPrivilege{SQLPrivilegeSelect}
11✔
6070
        }
6071
        return nil
×
6072
}
6073

6074
type AlterPrivilegesStmt struct {
6075
        database   string
6076
        user       string
6077
        privileges []SQLPrivilege
6078
        isGrant    bool
6079
}
6080

6081
func (stmt *AlterPrivilegesStmt) readOnly() bool {
2✔
6082
        return false
2✔
6083
}
2✔
6084

6085
func (stmt *AlterPrivilegesStmt) requiredPrivileges() []SQLPrivilege {
2✔
6086
        return nil
2✔
6087
}
2✔
6088

6089
func (stmt *AlterPrivilegesStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2✔
6090
        if tx.IsExplicitCloseRequired() {
3✔
6091
                return nil, fmt.Errorf("%w: user privileges modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
6092
        }
1✔
6093

6094
        if tx.engine.multidbHandler == nil {
1✔
6095
                return nil, ErrUnspecifiedMultiDBHandler
×
6096
        }
×
6097

6098
        var err error
1✔
6099
        if stmt.isGrant {
1✔
6100
                err = tx.engine.multidbHandler.GrantSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
×
6101
        } else {
1✔
6102
                err = tx.engine.multidbHandler.RevokeSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
1✔
6103
        }
1✔
6104
        return nil, err
1✔
6105
}
6106

6107
func (stmt *AlterPrivilegesStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
6108
        return nil
1✔
6109
}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc