• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

codenotary / immudb / 12273047947

11 Dec 2024 09:06AM UTC coverage: 89.286% (+0.02%) from 89.266%
12273047947

Pull #2036

gh-ci

ostafen
chore(embedded/sql): Add support for core pg_catalog tables (pg_class, pg_namespace, pg_roles)

Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>
Pull Request #2036: chore(embedded/sql): Add support for core pg_catalog tables (pg_class, pg_namespace, pg_roles)

160 of 184 new or added lines in 12 files covered. (86.96%)

1 existing line in 1 file now uncovered.

37641 of 42158 relevant lines covered (89.29%)

150814.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.56
/embedded/sql/stmt.go
1
/*
2
Copyright 2024 Codenotary Inc. All rights reserved.
3

4
SPDX-License-Identifier: BUSL-1.1
5
you may not use this file except in compliance with the License.
6
You may obtain a copy of the License at
7

8
    https://mariadb.com/bsl11/
9

10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
*/
16

17
package sql
18

19
import (
20
        "bytes"
21
        "context"
22
        "encoding/binary"
23
        "encoding/hex"
24
        "errors"
25
        "fmt"
26
        "math"
27
        "regexp"
28
        "strconv"
29
        "strings"
30
        "time"
31

32
        "github.com/codenotary/immudb/embedded/store"
33
        "github.com/google/uuid"
34
)
35

36
// Key prefixes used to store the SQL catalog and table data.
//
// Every key is additionally prefixed with the transaction's SQL prefix
// (tx.sqlPrefix()). The {1} component of catalog keys is the deprecated
// DatabaseID, which is always 1.
const (
	catalogPrefix          = "CTL."
	catalogTablePrefix     = "CTL.TABLE."     // (key=CTL.TABLE.{1}{tableID}, value={tableNAME})
	catalogColumnPrefix    = "CTL.COLUMN."    // (key=CTL.COLUMN.{1}{tableID}{colID}{colTYPE}, value={(autoIncrementFlag | nullableFlag){maxLen}{colNAME}})
	catalogIndexPrefix     = "CTL.INDEX."     // (key=CTL.INDEX.{1}{tableID}{indexID}, value={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)})
	catalogCheckPrefix     = "CTL.CHECK."     // (key=CTL.CHECK.{1}{tableID}{checkID}, value={nameLen-1}{name}{expText})
	catalogPrivilegePrefix = "CTL.PRIVILEGE." // key/value layout is defined where privileges are persisted (not in this section) - TODO document

	RowPrefix    = "R." // (key=R.{1}{tableID}{0}({null}({pkVal}{padding}{pkValLen})?)+, value={count (colID valLen val)+})
	MappedPrefix = "M." // (key=M.{tableID}{indexID}({null}({val}{padding}{valLen})?)*({pkVal}{padding}{pkValLen})+, value={count (colID valLen val)+})
)
47

48
// DatabaseID is the database identifier embedded in catalog keys. It is
// deprecated, but kept so that keys retain their existing layout for
// backwards compatibility.
const DatabaseID = uint32(1)

// PKIndexID is the ID of the primary-key index (presumably the first index
// created for each table).
const PKIndexID = uint32(0)

// Bit flags stored in the first byte of a persisted column definition.
//
// Despite its name, nullableFlag is set for columns declared NOT NULL (see
// persistColumn), i.e. it marks columns that are not nullable.
const (
	nullableFlag      byte = 1 << 0
	autoIncrementFlag byte = 1 << 1
)
57

58
// Names of the reserved system columns recognized by isReservedCol.
const (
	revCol        = "_rev"
	txMetadataCol = "_tx_metadata"
)

// reservedColumns is the set of reserved column names.
var reservedColumns = map[string]struct{}{
	revCol:        {},
	txMetadataCol: {},
}

// isReservedCol reports whether col is exactly one of the reserved column
// names. The comparison is case-sensitive.
func isReservedCol(col string) bool {
	if _, reserved := reservedColumns[col]; reserved {
		return true
	}
	return false
}
72

73
// SQLValueType is the name of a SQL type, e.g. "INTEGER".
type SQLValueType = string

const (
	IntegerType   SQLValueType = "INTEGER"
	BooleanType   SQLValueType = "BOOLEAN"
	VarcharType   SQLValueType = "VARCHAR"
	UUIDType      SQLValueType = "UUID"
	BLOBType      SQLValueType = "BLOB"
	Float64Type   SQLValueType = "FLOAT"
	TimestampType SQLValueType = "TIMESTAMP"
	AnyType       SQLValueType = "ANY"
	JSONType      SQLValueType = "JSON"
)

// IsNumericType reports whether t is one of the numeric SQL types, INTEGER
// or FLOAT. The comparison is exact and case-sensitive.
func IsNumericType(t SQLValueType) bool {
	switch t {
	case IntegerType, Float64Type:
		return true
	default:
		return false
	}
}
90

91
// Permission is the name of a user permission level.
type Permission = string

const (
	PermissionReadOnly  Permission = "READ"
	PermissionReadWrite Permission = "READWRITE"
	PermissionAdmin     Permission = "ADMIN"
	PermissionSysAdmin  Permission = "SYSADMIN"
)

// PermissionFromCode maps a numeric permission code to its Permission.
// Codes 1, 2 and 254 map to READ, READWRITE and ADMIN respectively; every
// other code (including 255) yields SYSADMIN.
//
// NOTE(review): unknown codes fall back to the most privileged level, so
// callers should only pass codes they know to be valid.
func PermissionFromCode(code uint32) Permission {
	var perm Permission

	switch code {
	case 1:
		perm = PermissionReadOnly
	case 2:
		perm = PermissionReadWrite
	case 254:
		perm = PermissionAdmin
	default:
		perm = PermissionSysAdmin
	}

	return perm
}
117

118
// AggregateFn is the name of a SQL aggregate function.
type AggregateFn = string

// Aggregate function names, listed alphabetically.
const (
	AVG   AggregateFn = "AVG"
	COUNT AggregateFn = "COUNT"
	MAX   AggregateFn = "MAX"
	MIN   AggregateFn = "MIN"
	SUM   AggregateFn = "SUM"
)
127

128
// CmpOperator identifies a binary comparison operator.
type CmpOperator = int

const (
	EQ CmpOperator = iota
	NE
	LT
	LE
	GT
	GE
)

// CmpOperatorToString returns the SQL symbol for op ("=", "!=", "<", "<=",
// ">", ">="), or "" when op is not a known comparison operator.
func CmpOperatorToString(op CmpOperator) string {
	symbols := [...]string{
		EQ: "=",
		NE: "!=",
		LT: "<",
		LE: "<=",
		GT: ">",
		GE: ">=",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
156

157
// LogicOperator identifies a boolean connective between conditions.
type LogicOperator = int

const (
	And LogicOperator = iota
	Or
)

// LogicOperatorToString returns the SQL keyword for op: "AND" for And and
// "OR" for every other value.
func LogicOperatorToString(op LogicOperator) string {
	if op != And {
		return "OR"
	}
	return "AND"
}
170

171
// NumOperator identifies a binary arithmetic operator.
type NumOperator = int

const (
	ADDOP NumOperator = iota
	SUBSOP
	DIVOP
	MULTOP
	MODOP
)

// NumOperatorString returns the SQL symbol for op ("+", "-", "/", "*", "%"),
// or "" when op is not a known arithmetic operator.
func NumOperatorString(op NumOperator) string {
	symbols := [...]string{
		ADDOP:  "+",
		SUBSOP: "-",
		DIVOP:  "/",
		MULTOP: "*",
		MODOP:  "%",
	}

	if op < 0 || op >= len(symbols) {
		return ""
	}
	return symbols[op]
}
196

197
// JoinType identifies how the rows of two data sources are joined.
type JoinType = int

// Join kinds, with their numeric values spelled out explicitly.
const (
	InnerJoin JoinType = 0
	LeftJoin  JoinType = 1
	RightJoin JoinType = 2
)
204

205
type SQLStmt interface {
206
        readOnly() bool
207
        requiredPrivileges() []SQLPrivilege
208
        execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error)
209
        inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error
210
}
211

212
type BeginTransactionStmt struct {
213
}
214

215
func (stmt *BeginTransactionStmt) readOnly() bool {
9✔
216
        return true
9✔
217
}
9✔
218

219
func (stmt *BeginTransactionStmt) requiredPrivileges() []SQLPrivilege {
9✔
220
        return nil
9✔
221
}
9✔
222

223
func (stmt *BeginTransactionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
3✔
224
        return nil
3✔
225
}
3✔
226

227
func (stmt *BeginTransactionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
61✔
228
        if tx.IsExplicitCloseRequired() {
62✔
229
                return nil, ErrNestedTxNotSupported
1✔
230
        }
1✔
231

232
        err := tx.RequireExplicitClose()
60✔
233
        if err == nil {
119✔
234
                // current tx can be reused as no changes were already made
59✔
235
                return tx, nil
59✔
236
        }
59✔
237

238
        // commit current transaction and start a fresh one
239

240
        err = tx.Commit(ctx)
1✔
241
        if err != nil {
1✔
242
                return nil, err
×
243
        }
×
244

245
        return tx.engine.NewTx(ctx, tx.opts.WithExplicitClose(true))
1✔
246
}
247

248
type CommitStmt struct {
249
}
250

251
func (stmt *CommitStmt) readOnly() bool {
113✔
252
        return true
113✔
253
}
113✔
254

255
func (stmt *CommitStmt) requiredPrivileges() []SQLPrivilege {
113✔
256
        return nil
113✔
257
}
113✔
258

259
func (stmt *CommitStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
260
        return nil
1✔
261
}
1✔
262

263
func (stmt *CommitStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
159✔
264
        if !tx.IsExplicitCloseRequired() {
160✔
265
                return nil, ErrNoOngoingTx
1✔
266
        }
1✔
267

268
        return nil, tx.Commit(ctx)
158✔
269
}
270

271
type RollbackStmt struct {
272
}
273

274
func (stmt *RollbackStmt) readOnly() bool {
1✔
275
        return true
1✔
276
}
1✔
277

278
func (stmt *RollbackStmt) requiredPrivileges() []SQLPrivilege {
1✔
279
        return nil
1✔
280
}
1✔
281

282
func (stmt *RollbackStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
283
        return nil
1✔
284
}
1✔
285

286
func (stmt *RollbackStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
287
        if !tx.IsExplicitCloseRequired() {
5✔
288
                return nil, ErrNoOngoingTx
1✔
289
        }
1✔
290

291
        return nil, tx.Cancel()
3✔
292
}
293

294
type CreateDatabaseStmt struct {
295
        DB          string
296
        ifNotExists bool
297
}
298

299
func (stmt *CreateDatabaseStmt) readOnly() bool {
14✔
300
        return false
14✔
301
}
14✔
302

303
func (stmt *CreateDatabaseStmt) requiredPrivileges() []SQLPrivilege {
14✔
304
        return []SQLPrivilege{SQLPrivilegeCreate}
14✔
305
}
14✔
306

307
func (stmt *CreateDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
308
        return nil
4✔
309
}
4✔
310

311
func (stmt *CreateDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
312
        if tx.IsExplicitCloseRequired() {
16✔
313
                return nil, fmt.Errorf("%w: database creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
314
        }
1✔
315

316
        if tx.engine.multidbHandler == nil {
16✔
317
                return nil, ErrUnspecifiedMultiDBHandler
2✔
318
        }
2✔
319

320
        return nil, tx.engine.multidbHandler.CreateDatabase(ctx, stmt.DB, stmt.ifNotExists)
12✔
321
}
322

323
type UseDatabaseStmt struct {
324
        DB string
325
}
326

327
func (stmt *UseDatabaseStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
328
        return nil
1✔
329
}
1✔
330

331
func (stmt *UseDatabaseStmt) readOnly() bool {
9✔
332
        return true
9✔
333
}
9✔
334

335
func (stmt *UseDatabaseStmt) requiredPrivileges() []SQLPrivilege {
9✔
336
        return nil
9✔
337
}
9✔
338

339
func (stmt *UseDatabaseStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
340
        if tx.IsExplicitCloseRequired() {
10✔
341
                return nil, fmt.Errorf("%w: database selection can NOT be executed within a transaction block", ErrNonTransactionalStmt)
1✔
342
        }
1✔
343

344
        if tx.engine.multidbHandler == nil {
9✔
345
                return nil, ErrUnspecifiedMultiDBHandler
1✔
346
        }
1✔
347

348
        return tx, tx.engine.multidbHandler.UseDatabase(ctx, stmt.DB)
7✔
349
}
350

351
type UseSnapshotStmt struct {
352
        period period
353
}
354

355
func (stmt *UseSnapshotStmt) readOnly() bool {
1✔
356
        return true
1✔
357
}
1✔
358

359
func (stmt *UseSnapshotStmt) requiredPrivileges() []SQLPrivilege {
1✔
360
        return nil
1✔
361
}
1✔
362

363
func (stmt *UseSnapshotStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
364
        return nil
1✔
365
}
1✔
366

367
func (stmt *UseSnapshotStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
1✔
368
        return nil, ErrNoSupported
1✔
369
}
1✔
370

371
type CreateUserStmt struct {
372
        username   string
373
        password   string
374
        permission Permission
375
}
376

377
func (stmt *CreateUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
378
        return nil
1✔
379
}
1✔
380

381
func (stmt *CreateUserStmt) readOnly() bool {
5✔
382
        return false
5✔
383
}
5✔
384

385
func (stmt *CreateUserStmt) requiredPrivileges() []SQLPrivilege {
5✔
386
        return []SQLPrivilege{SQLPrivilegeCreate}
5✔
387
}
5✔
388

389
func (stmt *CreateUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
390
        if tx.IsExplicitCloseRequired() {
7✔
391
                return nil, fmt.Errorf("%w: user creation can not be done within a transaction", ErrNonTransactionalStmt)
1✔
392
        }
1✔
393

394
        if tx.engine.multidbHandler == nil {
6✔
395
                return nil, ErrUnspecifiedMultiDBHandler
1✔
396
        }
1✔
397

398
        return nil, tx.engine.multidbHandler.CreateUser(ctx, stmt.username, stmt.password, stmt.permission)
4✔
399
}
400

401
type AlterUserStmt struct {
402
        username   string
403
        password   string
404
        permission Permission
405
}
406

407
func (stmt *AlterUserStmt) readOnly() bool {
4✔
408
        return false
4✔
409
}
4✔
410

411
func (stmt *AlterUserStmt) requiredPrivileges() []SQLPrivilege {
4✔
412
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
413
}
4✔
414

415
func (stmt *AlterUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
416
        return nil
1✔
417
}
1✔
418

419
func (stmt *AlterUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
5✔
420
        if tx.IsExplicitCloseRequired() {
6✔
421
                return nil, fmt.Errorf("%w: user modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
422
        }
1✔
423

424
        if tx.engine.multidbHandler == nil {
5✔
425
                return nil, ErrUnspecifiedMultiDBHandler
1✔
426
        }
1✔
427

428
        return nil, tx.engine.multidbHandler.AlterUser(ctx, stmt.username, stmt.password, stmt.permission)
3✔
429
}
430

431
type DropUserStmt struct {
432
        username string
433
}
434

435
func (stmt *DropUserStmt) readOnly() bool {
3✔
436
        return false
3✔
437
}
3✔
438

439
func (stmt *DropUserStmt) requiredPrivileges() []SQLPrivilege {
3✔
440
        return []SQLPrivilege{SQLPrivilegeDrop}
3✔
441
}
3✔
442

443
func (stmt *DropUserStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
444
        return nil
1✔
445
}
1✔
446

447
func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
3✔
448
        if tx.IsExplicitCloseRequired() {
4✔
449
                return nil, fmt.Errorf("%w: user deletion can not be done within a transaction", ErrNonTransactionalStmt)
1✔
450
        }
1✔
451

452
        if tx.engine.multidbHandler == nil {
3✔
453
                return nil, ErrUnspecifiedMultiDBHandler
1✔
454
        }
1✔
455

456
        return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
1✔
457
}
458

459
type CreateTableStmt struct {
460
        table       string
461
        ifNotExists bool
462
        colsSpec    []*ColSpec
463
        checks      []CheckConstraint
464
        pkColNames  []string
465
}
466

467
func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
38✔
468
        return &CreateTableStmt{table: table, ifNotExists: ifNotExists, colsSpec: colsSpec, pkColNames: pkColNames}
38✔
469
}
38✔
470

471
func (stmt *CreateTableStmt) readOnly() bool {
66✔
472
        return false
66✔
473
}
66✔
474

475
func (stmt *CreateTableStmt) requiredPrivileges() []SQLPrivilege {
64✔
476
        return []SQLPrivilege{SQLPrivilegeCreate}
64✔
477
}
64✔
478

479
func (stmt *CreateTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
4✔
480
        return nil
4✔
481
}
4✔
482

483
func zeroRow(tableName string, cols []*ColSpec) *Row {
242✔
484
        r := Row{
242✔
485
                ValuesByPosition: make([]TypedValue, len(cols)),
242✔
486
                ValuesBySelector: make(map[string]TypedValue, len(cols)),
242✔
487
        }
242✔
488

242✔
489
        for i, col := range cols {
1,020✔
490
                v := zeroForType(col.colType)
778✔
491

778✔
492
                r.ValuesByPosition[i] = v
778✔
493
                r.ValuesBySelector[EncodeSelector("", tableName, col.colName)] = v
778✔
494
        }
778✔
495
        return &r
242✔
496
}
497

498
func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
230✔
499
        if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
231✔
500
                return tx, nil
1✔
501
        }
1✔
502

503
        colSpecs := make(map[uint32]*ColSpec, len(stmt.colsSpec))
229✔
504
        for i, cs := range stmt.colsSpec {
947✔
505
                colSpecs[uint32(i)+1] = cs
718✔
506
        }
718✔
507

508
        row := zeroRow(stmt.table, stmt.colsSpec)
229✔
509
        for _, check := range stmt.checks {
238✔
510
                value, err := check.exp.reduce(tx, row, stmt.table)
9✔
511
                if err != nil {
11✔
512
                        return nil, err
2✔
513
                }
2✔
514

515
                if value.Type() != BooleanType {
7✔
516
                        return nil, ErrInvalidCheckConstraint
×
517
                }
×
518
        }
519

520
        nextUnnamedCheck := 0
227✔
521
        checks := make(map[string]CheckConstraint)
227✔
522
        for id, check := range stmt.checks {
234✔
523
                name := fmt.Sprintf("%s_check%d", stmt.table, nextUnnamedCheck+1)
7✔
524
                if check.name != "" {
9✔
525
                        name = check.name
2✔
526
                } else {
7✔
527
                        nextUnnamedCheck++
5✔
528
                }
5✔
529
                check.id = uint32(id)
7✔
530
                check.name = name
7✔
531
                checks[name] = check
7✔
532
        }
533

534
        table, err := tx.catalog.newTable(stmt.table, colSpecs, checks, uint32(len(colSpecs)))
227✔
535
        if err != nil {
233✔
536
                return nil, err
6✔
537
        }
6✔
538

539
        createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.pkColNames}
221✔
540
        _, err = createIndexStmt.execAt(ctx, tx, params)
221✔
541
        if err != nil {
226✔
542
                return nil, err
5✔
543
        }
5✔
544

545
        for _, col := range table.cols {
915✔
546
                if col.autoIncrement {
776✔
547
                        if len(table.primaryIndex.cols) > 1 || col.id != table.primaryIndex.cols[0].id {
78✔
548
                                return nil, ErrLimitedAutoIncrement
1✔
549
                        }
1✔
550
                }
551

552
                err := persistColumn(tx, col)
698✔
553
                if err != nil {
698✔
554
                        return nil, err
×
555
                }
×
556
        }
557

558
        for _, check := range checks {
222✔
559
                if err := persistCheck(tx, table, &check); err != nil {
7✔
560
                        return nil, err
×
561
                }
×
562
        }
563

564
        mappedKey := MapKey(tx.sqlPrefix(), catalogTablePrefix, EncodeID(DatabaseID), EncodeID(table.id))
215✔
565

215✔
566
        err = tx.set(mappedKey, nil, []byte(table.name))
215✔
567
        if err != nil {
215✔
568
                return nil, err
×
569
        }
×
570

571
        tx.mutatedCatalog = true
215✔
572

215✔
573
        return tx, nil
215✔
574
}
575

576
func persistColumn(tx *SQLTx, col *Column) error {
718✔
577
        //{auto_incremental | nullable}{maxLen}{colNAME})
718✔
578
        v := make([]byte, 1+4+len(col.colName))
718✔
579

718✔
580
        if col.autoIncrement {
794✔
581
                v[0] = v[0] | autoIncrementFlag
76✔
582
        }
76✔
583

584
        if col.notNull {
763✔
585
                v[0] = v[0] | nullableFlag
45✔
586
        }
45✔
587

588
        binary.BigEndian.PutUint32(v[1:], uint32(col.MaxLen()))
718✔
589

718✔
590
        copy(v[5:], []byte(col.Name()))
718✔
591

718✔
592
        mappedKey := MapKey(
718✔
593
                tx.sqlPrefix(),
718✔
594
                catalogColumnPrefix,
718✔
595
                EncodeID(DatabaseID),
718✔
596
                EncodeID(col.table.id),
718✔
597
                EncodeID(col.id),
718✔
598
                []byte(col.colType),
718✔
599
        )
718✔
600

718✔
601
        return tx.set(mappedKey, nil, v)
718✔
602
}
603

604
func persistCheck(tx *SQLTx, table *Table, check *CheckConstraint) error {
7✔
605
        mappedKey := MapKey(
7✔
606
                tx.sqlPrefix(),
7✔
607
                catalogCheckPrefix,
7✔
608
                EncodeID(DatabaseID),
7✔
609
                EncodeID(table.id),
7✔
610
                EncodeID(check.id),
7✔
611
        )
7✔
612

7✔
613
        name := check.name
7✔
614
        expText := check.exp.String()
7✔
615

7✔
616
        val := make([]byte, 2+len(name)+len(expText))
7✔
617

7✔
618
        if len(name) > 256 {
7✔
619
                return fmt.Errorf("constraint name len: %w", ErrMaxLengthExceeded)
×
620
        }
×
621

622
        val[0] = byte(len(name)) - 1
7✔
623

7✔
624
        copy(val[1:], []byte(name))
7✔
625
        copy(val[1+len(name):], []byte(expText))
7✔
626

7✔
627
        return tx.set(mappedKey, nil, val)
7✔
628
}
629

630
type ColSpec struct {
631
        colName       string
632
        colType       SQLValueType
633
        maxLen        int
634
        autoIncrement bool
635
        notNull       bool
636
}
637

638
func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {
188✔
639
        return &ColSpec{
188✔
640
                colName:       name,
188✔
641
                colType:       colType,
188✔
642
                maxLen:        maxLen,
188✔
643
                autoIncrement: autoIncrement,
188✔
644
                notNull:       notNull,
188✔
645
        }
188✔
646
}
188✔
647

648
type CreateIndexStmt struct {
649
        unique      bool
650
        ifNotExists bool
651
        table       string
652
        cols        []string
653
}
654

655
func NewCreateIndexStmt(table string, cols []string, isUnique bool) *CreateIndexStmt {
72✔
656
        return &CreateIndexStmt{unique: isUnique, table: table, cols: cols}
72✔
657
}
72✔
658

659
func (stmt *CreateIndexStmt) readOnly() bool {
7✔
660
        return false
7✔
661
}
7✔
662

663
func (stmt *CreateIndexStmt) requiredPrivileges() []SQLPrivilege {
7✔
664
        return []SQLPrivilege{SQLPrivilegeCreate}
7✔
665
}
7✔
666

667
func (stmt *CreateIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
668
        return nil
1✔
669
}
1✔
670

671
func (stmt *CreateIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
372✔
672
        if len(stmt.cols) < 1 {
373✔
673
                return nil, ErrIllegalArguments
1✔
674
        }
1✔
675

676
        if len(stmt.cols) > MaxNumberOfColumnsInIndex {
372✔
677
                return nil, ErrMaxNumberOfColumnsInIndexExceeded
1✔
678
        }
1✔
679

680
        table, err := tx.catalog.GetTableByName(stmt.table)
370✔
681
        if err != nil {
372✔
682
                return nil, err
2✔
683
        }
2✔
684

685
        colIDs := make([]uint32, len(stmt.cols))
368✔
686

368✔
687
        indexKeyLen := 0
368✔
688

368✔
689
        for i, colName := range stmt.cols {
763✔
690
                col, err := table.GetColumnByName(colName)
395✔
691
                if err != nil {
400✔
692
                        return nil, err
5✔
693
                }
5✔
694

695
                if col.Type() == JSONType {
392✔
696
                        return nil, ErrCannotIndexJson
2✔
697
                }
2✔
698

699
                if variableSizedType(col.colType) && !tx.engine.lazyIndexConstraintValidation && (col.MaxLen() == 0 || col.MaxLen() > MaxKeyLen) {
390✔
700
                        return nil, fmt.Errorf("%w: can not create index using column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
2✔
701
                }
2✔
702

703
                indexKeyLen += col.MaxLen()
386✔
704

386✔
705
                colIDs[i] = col.id
386✔
706
        }
707

708
        if !tx.engine.lazyIndexConstraintValidation && indexKeyLen > MaxKeyLen {
359✔
709
                return nil, fmt.Errorf("%w: can not create index using columns '%v'. Max key length is %d", ErrLimitedKeyType, stmt.cols, MaxKeyLen)
×
710
        }
×
711

712
        if stmt.unique && table.primaryIndex != nil {
379✔
713
                // check table is empty
20✔
714
                pkPrefix := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id))
20✔
715
                _, _, err := tx.getWithPrefix(ctx, pkPrefix, nil)
20✔
716
                if errors.Is(err, store.ErrIndexNotFound) {
20✔
717
                        return nil, ErrTableDoesNotExist
×
718
                }
×
719
                if err == nil {
21✔
720
                        return nil, ErrLimitedIndexCreation
1✔
721
                } else if !errors.Is(err, store.ErrKeyNotFound) {
20✔
722
                        return nil, err
×
723
                }
×
724
        }
725

726
        index, err := table.newIndex(stmt.unique, colIDs)
358✔
727
        if errors.Is(err, ErrIndexAlreadyExists) && stmt.ifNotExists {
360✔
728
                return tx, nil
2✔
729
        }
2✔
730
        if err != nil {
360✔
731
                return nil, err
4✔
732
        }
4✔
733

734
        // v={unique {colID1}(ASC|DESC)...{colIDN}(ASC|DESC)}
735
        // TODO: currently only ASC order is supported
736
        colSpecLen := EncIDLen + 1
352✔
737

352✔
738
        encodedValues := make([]byte, 1+len(index.cols)*colSpecLen)
352✔
739

352✔
740
        if index.IsUnique() {
586✔
741
                encodedValues[0] = 1
234✔
742
        }
234✔
743

744
        for i, col := range index.cols {
731✔
745
                copy(encodedValues[1+i*colSpecLen:], EncodeID(col.id))
379✔
746
        }
379✔
747

748
        mappedKey := MapKey(tx.sqlPrefix(), catalogIndexPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(index.id))
352✔
749

352✔
750
        err = tx.set(mappedKey, nil, encodedValues)
352✔
751
        if err != nil {
352✔
752
                return nil, err
×
753
        }
×
754

755
        tx.mutatedCatalog = true
352✔
756

352✔
757
        return tx, nil
352✔
758
}
759

760
type AddColumnStmt struct {
761
        table   string
762
        colSpec *ColSpec
763
}
764

765
func NewAddColumnStmt(table string, colSpec *ColSpec) *AddColumnStmt {
6✔
766
        return &AddColumnStmt{table: table, colSpec: colSpec}
6✔
767
}
6✔
768

769
func (stmt *AddColumnStmt) readOnly() bool {
4✔
770
        return false
4✔
771
}
4✔
772

773
func (stmt *AddColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
774
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
775
}
4✔
776

777
func (stmt *AddColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
778
        return nil
1✔
779
}
1✔
780

781
func (stmt *AddColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
782
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
783
        if err != nil {
20✔
784
                return nil, err
1✔
785
        }
1✔
786

787
        col, err := table.newColumn(stmt.colSpec)
18✔
788
        if err != nil {
24✔
789
                return nil, err
6✔
790
        }
6✔
791

792
        err = persistColumn(tx, col)
12✔
793
        if err != nil {
12✔
794
                return nil, err
×
795
        }
×
796

797
        tx.mutatedCatalog = true
12✔
798

12✔
799
        return tx, nil
12✔
800
}
801

802
type RenameTableStmt struct {
803
        oldName string
804
        newName string
805
}
806

807
func (stmt *RenameTableStmt) readOnly() bool {
1✔
808
        return false
1✔
809
}
1✔
810

811
func (stmt *RenameTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
812
        return []SQLPrivilege{SQLPrivilegeAlter}
1✔
813
}
1✔
814

815
func (stmt *RenameTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
816
        return nil
1✔
817
}
1✔
818

819
func (stmt *RenameTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
820
        table, err := tx.catalog.renameTable(stmt.oldName, stmt.newName)
6✔
821
        if err != nil {
10✔
822
                return nil, err
4✔
823
        }
4✔
824

825
        // update table name
826
        mappedKey := MapKey(
2✔
827
                tx.sqlPrefix(),
2✔
828
                catalogTablePrefix,
2✔
829
                EncodeID(DatabaseID),
2✔
830
                EncodeID(table.id),
2✔
831
        )
2✔
832
        err = tx.set(mappedKey, nil, []byte(stmt.newName))
2✔
833
        if err != nil {
2✔
834
                return nil, err
×
835
        }
×
836

837
        tx.mutatedCatalog = true
2✔
838

2✔
839
        return tx, nil
2✔
840
}
841

842
type RenameColumnStmt struct {
843
        table   string
844
        oldName string
845
        newName string
846
}
847

848
func NewRenameColumnStmt(table, oldName, newName string) *RenameColumnStmt {
3✔
849
        return &RenameColumnStmt{table: table, oldName: oldName, newName: newName}
3✔
850
}
3✔
851

852
func (stmt *RenameColumnStmt) readOnly() bool {
4✔
853
        return false
4✔
854
}
4✔
855

856
func (stmt *RenameColumnStmt) requiredPrivileges() []SQLPrivilege {
4✔
857
        return []SQLPrivilege{SQLPrivilegeAlter}
4✔
858
}
4✔
859

860
func (stmt *RenameColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
861
        return nil
1✔
862
}
1✔
863

864
func (stmt *RenameColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
10✔
865
        table, err := tx.catalog.GetTableByName(stmt.table)
10✔
866
        if err != nil {
11✔
867
                return nil, err
1✔
868
        }
1✔
869

870
        col, err := table.renameColumn(stmt.oldName, stmt.newName)
9✔
871
        if err != nil {
12✔
872
                return nil, err
3✔
873
        }
3✔
874

875
        err = persistColumn(tx, col)
6✔
876
        if err != nil {
6✔
877
                return nil, err
×
878
        }
×
879

880
        tx.mutatedCatalog = true
6✔
881

6✔
882
        return tx, nil
6✔
883
}
884

885
type DropColumnStmt struct {
886
        table   string
887
        colName string
888
}
889

890
func NewDropColumnStmt(table, colName string) *DropColumnStmt {
8✔
891
        return &DropColumnStmt{table: table, colName: colName}
8✔
892
}
8✔
893

894
func (stmt *DropColumnStmt) readOnly() bool {
2✔
895
        return false
2✔
896
}
2✔
897

898
func (stmt *DropColumnStmt) requiredPrivileges() []SQLPrivilege {
2✔
899
        return []SQLPrivilege{SQLPrivilegeDrop}
2✔
900
}
2✔
901

902
func (stmt *DropColumnStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
903
        return nil
1✔
904
}
1✔
905

906
func (stmt *DropColumnStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
19✔
907
        table, err := tx.catalog.GetTableByName(stmt.table)
19✔
908
        if err != nil {
21✔
909
                return nil, err
2✔
910
        }
2✔
911

912
        col, err := table.GetColumnByName(stmt.colName)
17✔
913
        if err != nil {
21✔
914
                return nil, err
4✔
915
        }
4✔
916

917
        err = canDropColumn(tx, table, col)
13✔
918
        if err != nil {
14✔
919
                return nil, err
1✔
920
        }
1✔
921

922
        err = table.deleteColumn(col)
12✔
923
        if err != nil {
16✔
924
                return nil, err
4✔
925
        }
4✔
926

927
        err = persistColumnDeletion(ctx, tx, col)
8✔
928
        if err != nil {
8✔
929
                return nil, err
×
930
        }
×
931

932
        tx.mutatedCatalog = true
8✔
933

8✔
934
        return tx, nil
8✔
935
}
936

937
func canDropColumn(tx *SQLTx, table *Table, col *Column) error {
13✔
938
        colSpecs := make([]*ColSpec, 0, len(table.Cols())-1)
13✔
939
        for _, c := range table.cols {
86✔
940
                if c.id != col.id {
133✔
941
                        colSpecs = append(colSpecs, &ColSpec{colName: c.Name(), colType: c.Type()})
60✔
942
                }
60✔
943
        }
944

945
        row := zeroRow(table.Name(), colSpecs)
13✔
946
        for name, check := range table.checkConstraints {
17✔
947
                _, err := check.exp.reduce(tx, row, table.name)
4✔
948
                if errors.Is(err, ErrColumnDoesNotExist) {
5✔
949
                        return fmt.Errorf("%w %s because %s constraint requires it", ErrCannotDropColumn, col.Name(), name)
1✔
950
                }
1✔
951

952
                if err != nil {
3✔
953
                        return err
×
954
                }
×
955
        }
956
        return nil
12✔
957
}
958

959
func persistColumnDeletion(ctx context.Context, tx *SQLTx, col *Column) error {
9✔
960
        mappedKey := MapKey(
9✔
961
                tx.sqlPrefix(),
9✔
962
                catalogColumnPrefix,
9✔
963
                EncodeID(DatabaseID),
9✔
964
                EncodeID(col.table.id),
9✔
965
                EncodeID(col.id),
9✔
966
                []byte(col.colType),
9✔
967
        )
9✔
968

9✔
969
        return tx.delete(ctx, mappedKey)
9✔
970
}
9✔
971

972
// DropConstraintStmt removes a named check constraint from a table.
type DropConstraintStmt struct {
	table          string // name of the table owning the constraint
	constraintName string // name of the check constraint to drop
}
976

977
func (stmt *DropConstraintStmt) readOnly() bool {
×
978
        return false
×
979
}
×
980

981
func (stmt *DropConstraintStmt) requiredPrivileges() []SQLPrivilege {
×
982
        return []SQLPrivilege{SQLPrivilegeDrop}
×
983
}
×
984

985
func (stmt *DropConstraintStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
4✔
986
        table, err := tx.catalog.GetTableByName(stmt.table)
4✔
987
        if err != nil {
4✔
988
                return nil, err
×
989
        }
×
990

991
        id, err := table.deleteCheck(stmt.constraintName)
4✔
992
        if err != nil {
5✔
993
                return nil, err
1✔
994
        }
1✔
995

996
        err = persistCheckDeletion(ctx, tx, table.id, id)
3✔
997

3✔
998
        tx.mutatedCatalog = true
3✔
999

3✔
1000
        return tx, err
3✔
1001
}
1002

1003
func persistCheckDeletion(ctx context.Context, tx *SQLTx, tableID uint32, checkId uint32) error {
3✔
1004
        mappedKey := MapKey(
3✔
1005
                tx.sqlPrefix(),
3✔
1006
                catalogCheckPrefix,
3✔
1007
                EncodeID(DatabaseID),
3✔
1008
                EncodeID(tableID),
3✔
1009
                EncodeID(checkId),
3✔
1010
        )
3✔
1011
        return tx.delete(ctx, mappedKey)
3✔
1012
}
3✔
1013

1014
func (stmt *DropConstraintStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
1015
        return nil
×
1016
}
×
1017

1018
type UpsertIntoStmt struct {
1019
        isInsert   bool
1020
        tableRef   *tableRef
1021
        cols       []string
1022
        ds         DataSource
1023
        onConflict *OnConflictDo
1024
}
1025

1026
func (stmt *UpsertIntoStmt) readOnly() bool {
101✔
1027
        return false
101✔
1028
}
101✔
1029

1030
func (stmt *UpsertIntoStmt) requiredPrivileges() []SQLPrivilege {
101✔
1031
        privileges := stmt.privileges()
101✔
1032
        if stmt.ds != nil {
200✔
1033
                privileges = append(privileges, stmt.ds.requiredPrivileges()...)
99✔
1034
        }
99✔
1035
        return privileges
101✔
1036
}
1037

1038
func (stmt *UpsertIntoStmt) privileges() []SQLPrivilege {
101✔
1039
        if stmt.isInsert {
190✔
1040
                return []SQLPrivilege{SQLPrivilegeInsert}
89✔
1041
        }
89✔
1042
        return []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate}
12✔
1043
}
1044

1045
func NewUpsertIntoStmt(table string, cols []string, ds DataSource, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt {
120✔
1046
        return &UpsertIntoStmt{
120✔
1047
                isInsert:   isInsert,
120✔
1048
                tableRef:   NewTableRef(table, ""),
120✔
1049
                cols:       cols,
120✔
1050
                ds:         ds,
120✔
1051
                onConflict: onConflict,
120✔
1052
        }
120✔
1053
}
120✔
1054

1055
type RowSpec struct {
1056
        Values []ValueExp
1057
}
1058

1059
func NewRowSpec(values []ValueExp) *RowSpec {
129✔
1060
        return &RowSpec{
129✔
1061
                Values: values,
129✔
1062
        }
129✔
1063
}
129✔
1064

1065
// OnConflictDo carries the ON CONFLICT clause of an INSERT. It has no options
// yet: an INSERT with a non-nil OnConflictDo skips rows whose primary key is
// already present ("ON CONFLICT DO NOTHING").
type OnConflictDo struct{}
1066

1067
func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
11✔
1068
        ds, ok := stmt.ds.(*valuesDataSource)
11✔
1069
        if !ok {
11✔
1070
                return stmt.ds.inferParameters(ctx, tx, params)
×
1071
        }
×
1072

1073
        emptyDescriptors := make(map[string]ColDescriptor)
11✔
1074
        for _, row := range ds.rows {
23✔
1075
                if len(stmt.cols) != len(row.Values) {
13✔
1076
                        return ErrInvalidNumberOfValues
1✔
1077
                }
1✔
1078

1079
                for i, val := range row.Values {
36✔
1080
                        table, err := stmt.tableRef.referencedTable(tx)
25✔
1081
                        if err != nil {
26✔
1082
                                return err
1✔
1083
                        }
1✔
1084

1085
                        col, err := table.GetColumnByName(stmt.cols[i])
24✔
1086
                        if err != nil {
25✔
1087
                                return err
1✔
1088
                        }
1✔
1089

1090
                        err = val.requiresType(col.colType, emptyDescriptors, params, table.name)
23✔
1091
                        if err != nil {
25✔
1092
                                return err
2✔
1093
                        }
2✔
1094
                }
1095
        }
1096
        return nil
6✔
1097
}
1098

1099
func (stmt *UpsertIntoStmt) validate(table *Table) (map[uint32]int, error) {
2,114✔
1100
        selPosByColID := make(map[uint32]int, len(stmt.cols))
2,114✔
1101

2,114✔
1102
        for i, c := range stmt.cols {
9,734✔
1103
                col, err := table.GetColumnByName(c)
7,620✔
1104
                if err != nil {
7,622✔
1105
                        return nil, err
2✔
1106
                }
2✔
1107

1108
                _, duplicated := selPosByColID[col.id]
7,618✔
1109
                if duplicated {
7,619✔
1110
                        return nil, fmt.Errorf("%w (%s)", ErrDuplicatedColumn, col.colName)
1✔
1111
                }
1✔
1112

1113
                selPosByColID[col.id] = i
7,617✔
1114
        }
1115

1116
        return selPosByColID, nil
2,111✔
1117
}
1118

1119
func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2,117✔
1120
        table, err := stmt.tableRef.referencedTable(tx)
2,117✔
1121
        if err != nil {
2,120✔
1122
                return nil, err
3✔
1123
        }
3✔
1124

1125
        selPosByColID, err := stmt.validate(table)
2,114✔
1126
        if err != nil {
2,117✔
1127
                return nil, err
3✔
1128
        }
3✔
1129

1130
        r := &Row{
2,111✔
1131
                ValuesByPosition: make([]TypedValue, len(table.cols)),
2,111✔
1132
                ValuesBySelector: make(map[string]TypedValue),
2,111✔
1133
        }
2,111✔
1134

2,111✔
1135
        reader, err := stmt.ds.Resolve(ctx, tx, params, nil)
2,111✔
1136
        if err != nil {
2,111✔
1137
                return nil, err
×
1138
        }
×
1139
        defer reader.Close()
2,111✔
1140

2,111✔
1141
        for {
6,486✔
1142
                row, err := reader.Read(ctx)
4,375✔
1143
                if errors.Is(err, ErrNoMoreRows) {
6,443✔
1144
                        break
2,068✔
1145
                }
1146
                if err != nil {
2,317✔
1147
                        return nil, err
10✔
1148
                }
10✔
1149

1150
                if len(row.ValuesByPosition) != len(stmt.cols) {
2,299✔
1151
                        return nil, ErrInvalidNumberOfValues
2✔
1152
                }
2✔
1153

1154
                valuesByColID := make(map[uint32]TypedValue)
2,295✔
1155

2,295✔
1156
                var pkMustExist bool
2,295✔
1157

2,295✔
1158
                for colID, col := range table.colsByID {
11,740✔
1159
                        colPos, specified := selPosByColID[colID]
9,445✔
1160
                        if !specified {
10,691✔
1161
                                // TODO: Default values
1,246✔
1162
                                if col.notNull && !col.autoIncrement {
1,247✔
1163
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
1✔
1164
                                }
1✔
1165

1166
                                // inject auto-incremental pk value
1167
                                if stmt.isInsert && col.autoIncrement {
2,303✔
1168
                                        // current implementation assumes only PK can be set as autoincremental
1,058✔
1169
                                        table.maxPK++
1,058✔
1170

1,058✔
1171
                                        pkCol := table.primaryIndex.cols[0]
1,058✔
1172
                                        valuesByColID[pkCol.id] = &Integer{val: table.maxPK}
1,058✔
1173

1,058✔
1174
                                        if _, ok := tx.firstInsertedPKs[table.name]; !ok {
1,924✔
1175
                                                tx.firstInsertedPKs[table.name] = table.maxPK
866✔
1176
                                        }
866✔
1177
                                        tx.lastInsertedPKs[table.name] = table.maxPK
1,058✔
1178
                                }
1179

1180
                                continue
1,245✔
1181
                        }
1182

1183
                        // value was specified
1184
                        cVal := row.ValuesByPosition[colPos]
8,199✔
1185

8,199✔
1186
                        val, err := cVal.substitute(params)
8,199✔
1187
                        if err != nil {
8,199✔
1188
                                return nil, err
×
1189
                        }
×
1190

1191
                        rval, err := val.reduce(tx, nil, table.name)
8,199✔
1192
                        if err != nil {
8,199✔
1193
                                return nil, err
×
1194
                        }
×
1195

1196
                        if rval.IsNull() {
8,297✔
1197
                                if col.notNull || col.autoIncrement {
98✔
1198
                                        return nil, fmt.Errorf("%w (%s)", ErrNotNullableColumnCannotBeNull, col.colName)
×
1199
                                }
×
1200

1201
                                continue
98✔
1202
                        }
1203

1204
                        if col.autoIncrement {
8,120✔
1205
                                // validate specified value
19✔
1206
                                nl, isNumber := rval.RawValue().(int64)
19✔
1207
                                if !isNumber {
19✔
1208
                                        return nil, fmt.Errorf("%w (expecting numeric value)", ErrInvalidValue)
×
1209
                                }
×
1210

1211
                                pkMustExist = nl <= table.maxPK
19✔
1212

19✔
1213
                                if _, ok := tx.firstInsertedPKs[table.name]; !ok {
38✔
1214
                                        tx.firstInsertedPKs[table.name] = nl
19✔
1215
                                }
19✔
1216
                                tx.lastInsertedPKs[table.name] = nl
19✔
1217
                        }
1218

1219
                        valuesByColID[colID] = rval
8,101✔
1220
                }
1221

1222
                for i, col := range table.cols {
11,735✔
1223
                        v := valuesByColID[col.id]
9,441✔
1224

9,441✔
1225
                        if v == nil {
9,725✔
1226
                                v = NewNull(AnyType)
284✔
1227
                        } else if len(table.checkConstraints) > 0 && col.Type() == JSONType {
9,446✔
1228
                                s, _ := v.RawValue().(string)
5✔
1229
                                jsonVal, err := NewJsonFromString(s)
5✔
1230
                                if err != nil {
5✔
1231
                                        return nil, err
×
1232
                                }
×
1233
                                v = jsonVal
5✔
1234
                        }
1235

1236
                        r.ValuesByPosition[i] = v
9,441✔
1237
                        r.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
9,441✔
1238
                }
1239

1240
                if err := checkConstraints(tx, table.checkConstraints, r, table.name); err != nil {
2,300✔
1241
                        return nil, err
6✔
1242
                }
6✔
1243

1244
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
2,288✔
1245
                if err != nil {
2,293✔
1246
                        return nil, err
5✔
1247
                }
5✔
1248

1249
                // pk entry
1250
                mappedPKey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
2,283✔
1251
                if len(mappedPKey) > MaxKeyLen {
2,283✔
1252
                        return nil, ErrMaxKeyLengthExceeded
×
1253
                }
×
1254

1255
                _, err = tx.get(ctx, mappedPKey)
2,283✔
1256
                if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
2,283✔
1257
                        return nil, err
×
1258
                }
×
1259

1260
                if errors.Is(err, store.ErrKeyNotFound) && pkMustExist {
2,285✔
1261
                        return nil, fmt.Errorf("%w: specified value must be greater than current one", ErrInvalidValue)
2✔
1262
                }
2✔
1263

1264
                if stmt.isInsert {
4,381✔
1265
                        if err == nil && stmt.onConflict == nil {
2,104✔
1266
                                return nil, store.ErrKeyAlreadyExists
4✔
1267
                        }
4✔
1268

1269
                        if err == nil && stmt.onConflict != nil {
2,099✔
1270
                                // TODO: conflict resolution may be extended. Currently only supports "ON CONFLICT DO NOTHING"
3✔
1271
                                continue
3✔
1272
                        }
1273
                }
1274

1275
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, !stmt.isInsert)
2,274✔
1276
                if err != nil {
2,287✔
1277
                        return nil, err
13✔
1278
                }
13✔
1279
        }
1280
        return tx, nil
2,068✔
1281
}
1282

1283
func checkConstraints(tx *SQLTx, checks map[string]CheckConstraint, row *Row, table string) error {
2,326✔
1284
        for _, check := range checks {
2,373✔
1285
                val, err := check.exp.reduce(tx, row, table)
47✔
1286
                if err != nil {
48✔
1287
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, err)
1✔
1288
                }
1✔
1289

1290
                if val.Type() != BooleanType {
46✔
1291
                        return ErrInvalidCheckConstraint
×
1292
                }
×
1293

1294
                if !val.RawValue().(bool) {
53✔
1295
                        return fmt.Errorf("%w: %s", ErrCheckConstraintViolation, check.exp.String())
7✔
1296
                }
7✔
1297
        }
1298
        return nil
2,318✔
1299
}
1300

1301
func (tx *SQLTx) encodeRowValue(valuesByColID map[uint32]TypedValue, table *Table) ([]byte, error) {
2,461✔
1302
        valbuf := bytes.Buffer{}
2,461✔
1303

2,461✔
1304
        // null values are not serialized
2,461✔
1305
        encodedVals := 0
2,461✔
1306
        for _, v := range valuesByColID {
12,035✔
1307
                if !v.IsNull() {
19,130✔
1308
                        encodedVals++
9,556✔
1309
                }
9,556✔
1310
        }
1311

1312
        b := make([]byte, EncLenLen)
2,461✔
1313
        binary.BigEndian.PutUint32(b, uint32(encodedVals))
2,461✔
1314

2,461✔
1315
        _, err := valbuf.Write(b)
2,461✔
1316
        if err != nil {
2,461✔
1317
                return nil, err
×
1318
        }
×
1319

1320
        for _, col := range table.cols {
12,295✔
1321
                rval, specified := valuesByColID[col.id]
9,834✔
1322
                if !specified || rval.IsNull() {
10,118✔
1323
                        continue
284✔
1324
                }
1325

1326
                b := make([]byte, EncIDLen)
9,550✔
1327
                binary.BigEndian.PutUint32(b, uint32(col.id))
9,550✔
1328

9,550✔
1329
                _, err = valbuf.Write(b)
9,550✔
1330
                if err != nil {
9,550✔
1331
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1332
                }
×
1333

1334
                encVal, err := EncodeValue(rval, col.colType, col.MaxLen())
9,550✔
1335
                if err != nil {
9,558✔
1336
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
8✔
1337
                }
8✔
1338

1339
                _, err = valbuf.Write(encVal)
9,542✔
1340
                if err != nil {
9,542✔
1341
                        return nil, fmt.Errorf("%w: table: %s, column: %s", err, table.name, col.colName)
×
1342
                }
×
1343
        }
1344

1345
        return valbuf.Bytes(), nil
2,453✔
1346
}
1347

1348
func (tx *SQLTx) doUpsert(ctx context.Context, pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table, reuseIndex bool) error {
2,304✔
1349
        var reusableIndexEntries map[uint32]struct{}
2,304✔
1350

2,304✔
1351
        if reuseIndex && len(table.indexes) > 1 {
2,361✔
1352
                currPKRow, err := tx.fetchPKRow(ctx, table, valuesByColID)
57✔
1353
                if err == nil {
93✔
1354
                        currValuesByColID := make(map[uint32]TypedValue, len(currPKRow.ValuesBySelector))
36✔
1355

36✔
1356
                        for _, col := range table.cols {
161✔
1357
                                encSel := EncodeSelector("", table.name, col.colName)
125✔
1358
                                currValuesByColID[col.id] = currPKRow.ValuesBySelector[encSel]
125✔
1359
                        }
125✔
1360

1361
                        reusableIndexEntries, err = tx.deprecateIndexEntries(pkEncVals, currValuesByColID, valuesByColID, table)
36✔
1362
                        if err != nil {
36✔
1363
                                return err
×
1364
                        }
×
1365
                } else if !errors.Is(err, ErrNoMoreRows) {
21✔
1366
                        return err
×
1367
                }
×
1368
        }
1369

1370
        rowKey := MapKey(tx.sqlPrefix(), RowPrefix, EncodeID(DatabaseID), EncodeID(table.id), EncodeID(PKIndexID), pkEncVals)
2,304✔
1371

2,304✔
1372
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
2,304✔
1373
        if err != nil {
2,312✔
1374
                return err
8✔
1375
        }
8✔
1376

1377
        err = tx.set(rowKey, nil, encodedRowValue)
2,296✔
1378
        if err != nil {
2,296✔
1379
                return err
×
1380
        }
×
1381

1382
        // create in-memory and validate entries for secondary indexes
1383
        for _, index := range table.indexes {
5,487✔
1384
                if index.IsPrimary() {
5,487✔
1385
                        continue
2,296✔
1386
                }
1387

1388
                if reusableIndexEntries != nil {
972✔
1389
                        _, reusable := reusableIndexEntries[index.id]
77✔
1390
                        if reusable {
127✔
1391
                                continue
50✔
1392
                        }
1393
                }
1394

1395
                encodedValues := make([][]byte, 2+len(index.cols))
845✔
1396
                encodedValues[0] = EncodeID(table.id)
845✔
1397
                encodedValues[1] = EncodeID(index.id)
845✔
1398

845✔
1399
                indexKeyLen := 0
845✔
1400

845✔
1401
                for i, col := range index.cols {
1,765✔
1402
                        rval, specified := valuesByColID[col.id]
920✔
1403
                        if !specified {
993✔
1404
                                rval = &NullValue{t: col.colType}
73✔
1405
                        }
73✔
1406

1407
                        encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
920✔
1408
                        if err != nil {
920✔
1409
                                return fmt.Errorf("%w: index on '%s' and column '%s'", err, index.Name(), col.colName)
×
1410
                        }
×
1411

1412
                        if n > MaxKeyLen {
920✔
1413
                                return fmt.Errorf("%w: can not index entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1414
                        }
×
1415

1416
                        indexKeyLen += n
920✔
1417

920✔
1418
                        encodedValues[i+2] = encVal
920✔
1419
                }
1420

1421
                if indexKeyLen > MaxKeyLen {
845✔
1422
                        return fmt.Errorf("%w: can not index entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1423
                }
×
1424

1425
                smkey := MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...)
845✔
1426

845✔
1427
                // no other equivalent entry should be already indexed
845✔
1428
                if index.IsUnique() {
922✔
1429
                        _, valRef, err := tx.getWithPrefix(ctx, smkey, nil)
77✔
1430
                        if err == nil && (valRef.KVMetadata() == nil || !valRef.KVMetadata().Deleted()) {
82✔
1431
                                return store.ErrKeyAlreadyExists
5✔
1432
                        } else if !errors.Is(err, store.ErrKeyNotFound) {
77✔
1433
                                return err
×
1434
                        }
×
1435
                }
1436

1437
                err = tx.setTransient(smkey, nil, encodedRowValue) // only-indexable
840✔
1438
                if err != nil {
840✔
1439
                        return err
×
1440
                }
×
1441
        }
1442

1443
        tx.updatedRows++
2,291✔
1444

2,291✔
1445
        return nil
2,291✔
1446
}
1447

1448
func encodedKey(index *Index, valuesByColID map[uint32]TypedValue) ([]byte, error) {
13,946✔
1449
        valbuf := bytes.Buffer{}
13,946✔
1450

13,946✔
1451
        indexKeyLen := 0
13,946✔
1452

13,946✔
1453
        for _, col := range index.cols {
27,904✔
1454
                rval, specified := valuesByColID[col.id]
13,958✔
1455
                if !specified || rval.IsNull() {
13,961✔
1456
                        return nil, ErrPKCanNotBeNull
3✔
1457
                }
3✔
1458

1459
                encVal, n, err := EncodeValueAsKey(rval, col.colType, col.MaxLen())
13,955✔
1460
                if err != nil {
13,957✔
1461
                        return nil, fmt.Errorf("%w: index of table '%s' and column '%s'", err, index.table.name, col.colName)
2✔
1462
                }
2✔
1463

1464
                if n > MaxKeyLen {
13,953✔
1465
                        return nil, fmt.Errorf("%w: invalid key entry for column '%s'. Max key length for variable columns is %d", ErrLimitedKeyType, col.colName, MaxKeyLen)
×
1466
                }
×
1467

1468
                indexKeyLen += n
13,953✔
1469

13,953✔
1470
                _, err = valbuf.Write(encVal)
13,953✔
1471
                if err != nil {
13,953✔
1472
                        return nil, err
×
1473
                }
×
1474
        }
1475

1476
        if indexKeyLen > MaxKeyLen {
13,941✔
1477
                return nil, fmt.Errorf("%w: invalid key entry using columns '%v'. Max key length is %d", ErrLimitedKeyType, index.cols, MaxKeyLen)
×
1478
        }
×
1479

1480
        return valbuf.Bytes(), nil
13,941✔
1481
}
1482

1483
func (tx *SQLTx) fetchPKRow(ctx context.Context, table *Table, valuesByColID map[uint32]TypedValue) (*Row, error) {
57✔
1484
        pkRanges := make(map[uint32]*typedValueRange, len(table.primaryIndex.cols))
57✔
1485

57✔
1486
        for _, pkCol := range table.primaryIndex.cols {
114✔
1487
                pkVal := valuesByColID[pkCol.id]
57✔
1488

57✔
1489
                pkRanges[pkCol.id] = &typedValueRange{
57✔
1490
                        lRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1491
                        hRange: &typedValueSemiRange{val: pkVal, inclusive: true},
57✔
1492
                }
57✔
1493
        }
57✔
1494

1495
        scanSpecs := &ScanSpecs{
57✔
1496
                Index:         table.primaryIndex,
57✔
1497
                rangesByColID: pkRanges,
57✔
1498
        }
57✔
1499

57✔
1500
        r, err := newRawRowReader(tx, nil, table, period{}, table.name, scanSpecs)
57✔
1501
        if err != nil {
57✔
1502
                return nil, err
×
1503
        }
×
1504

1505
        defer func() {
114✔
1506
                r.Close()
57✔
1507
        }()
57✔
1508

1509
        return r.Read(ctx)
57✔
1510
}
1511

1512
// deprecateIndexEntries mark previous index entries as deleted
1513
func (tx *SQLTx) deprecateIndexEntries(
1514
        pkEncVals []byte,
1515
        currValuesByColID, newValuesByColID map[uint32]TypedValue,
1516
        table *Table) (reusableIndexEntries map[uint32]struct{}, err error) {
36✔
1517

36✔
1518
        encodedRowValue, err := tx.encodeRowValue(currValuesByColID, table)
36✔
1519
        if err != nil {
36✔
1520
                return nil, err
×
1521
        }
×
1522

1523
        reusableIndexEntries = make(map[uint32]struct{})
36✔
1524

36✔
1525
        for _, index := range table.indexes {
149✔
1526
                if index.IsPrimary() {
149✔
1527
                        continue
36✔
1528
                }
1529

1530
                encodedValues := make([][]byte, 2+len(index.cols)+1)
77✔
1531
                encodedValues[0] = EncodeID(table.id)
77✔
1532
                encodedValues[1] = EncodeID(index.id)
77✔
1533
                encodedValues[len(encodedValues)-1] = pkEncVals
77✔
1534

77✔
1535
                // existent index entry is deleted only if it differs from existent one
77✔
1536
                sameIndexKey := true
77✔
1537

77✔
1538
                for i, col := range index.cols {
159✔
1539
                        currVal, specified := currValuesByColID[col.id]
82✔
1540
                        if !specified {
82✔
1541
                                currVal = &NullValue{t: col.colType}
×
1542
                        }
×
1543

1544
                        newVal, specified := newValuesByColID[col.id]
82✔
1545
                        if !specified {
86✔
1546
                                newVal = &NullValue{t: col.colType}
4✔
1547
                        }
4✔
1548

1549
                        r, err := currVal.Compare(newVal)
82✔
1550
                        if err != nil {
82✔
1551
                                return nil, err
×
1552
                        }
×
1553

1554
                        sameIndexKey = sameIndexKey && r == 0
82✔
1555

82✔
1556
                        encVal, _, _ := EncodeValueAsKey(currVal, col.colType, col.MaxLen())
82✔
1557

82✔
1558
                        encodedValues[i+3] = encVal
82✔
1559
                }
1560

1561
                // mark existent index entry as deleted
1562
                if sameIndexKey {
127✔
1563
                        reusableIndexEntries[index.id] = struct{}{}
50✔
1564
                } else {
77✔
1565
                        md := store.NewKVMetadata()
27✔
1566

27✔
1567
                        md.AsDeleted(true)
27✔
1568

27✔
1569
                        err = tx.set(MapKey(tx.sqlPrefix(), MappedPrefix, encodedValues...), md, encodedRowValue)
27✔
1570
                        if err != nil {
27✔
1571
                                return nil, err
×
1572
                        }
×
1573
                }
1574
        }
1575

1576
        return reusableIndexEntries, nil
36✔
1577
}
1578

1579
type UpdateStmt struct {
1580
        tableRef *tableRef
1581
        where    ValueExp
1582
        updates  []*colUpdate
1583
        indexOn  []string
1584
        limit    ValueExp
1585
        offset   ValueExp
1586
}
1587

1588
type colUpdate struct {
1589
        col string
1590
        op  CmpOperator
1591
        val ValueExp
1592
}
1593

1594
func (stmt *UpdateStmt) readOnly() bool {
4✔
1595
        return false
4✔
1596
}
4✔
1597

1598
func (stmt *UpdateStmt) requiredPrivileges() []SQLPrivilege {
4✔
1599
        return []SQLPrivilege{SQLPrivilegeUpdate}
4✔
1600
}
4✔
1601

1602
func (stmt *UpdateStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1603
        selectStmt := &SelectStmt{
1✔
1604
                ds:    stmt.tableRef,
1✔
1605
                where: stmt.where,
1✔
1606
        }
1✔
1607

1✔
1608
        err := selectStmt.inferParameters(ctx, tx, params)
1✔
1609
        if err != nil {
1✔
1610
                return err
×
1611
        }
×
1612

1613
        table, err := stmt.tableRef.referencedTable(tx)
1✔
1614
        if err != nil {
1✔
1615
                return err
×
1616
        }
×
1617

1618
        for _, update := range stmt.updates {
2✔
1619
                col, err := table.GetColumnByName(update.col)
1✔
1620
                if err != nil {
1✔
1621
                        return err
×
1622
                }
×
1623

1624
                err = update.val.requiresType(col.colType, make(map[string]ColDescriptor), params, table.name)
1✔
1625
                if err != nil {
1✔
1626
                        return err
×
1627
                }
×
1628
        }
1629

1630
        return nil
1✔
1631
}
1632

1633
func (stmt *UpdateStmt) validate(table *Table) error {
21✔
1634
        colIDs := make(map[uint32]struct{}, len(stmt.updates))
21✔
1635

21✔
1636
        for _, update := range stmt.updates {
44✔
1637
                if update.op != EQ {
23✔
1638
                        return ErrIllegalArguments
×
1639
                }
×
1640

1641
                col, err := table.GetColumnByName(update.col)
23✔
1642
                if err != nil {
24✔
1643
                        return err
1✔
1644
                }
1✔
1645

1646
                if table.PrimaryIndex().IncludesCol(col.id) {
22✔
1647
                        return ErrPKCanNotBeUpdated
×
1648
                }
×
1649

1650
                _, duplicated := colIDs[col.id]
22✔
1651
                if duplicated {
22✔
1652
                        return ErrDuplicatedColumn
×
1653
                }
×
1654

1655
                colIDs[col.id] = struct{}{}
22✔
1656
        }
1657

1658
        return nil
20✔
1659
}
1660

1661
func (stmt *UpdateStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
22✔
1662
        selectStmt := &SelectStmt{
22✔
1663
                ds:      stmt.tableRef,
22✔
1664
                where:   stmt.where,
22✔
1665
                indexOn: stmt.indexOn,
22✔
1666
                limit:   stmt.limit,
22✔
1667
                offset:  stmt.offset,
22✔
1668
        }
22✔
1669

22✔
1670
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
22✔
1671
        if err != nil {
23✔
1672
                return nil, err
1✔
1673
        }
1✔
1674
        defer rowReader.Close()
21✔
1675

21✔
1676
        table := rowReader.ScanSpecs().Index.table
21✔
1677

21✔
1678
        err = stmt.validate(table)
21✔
1679
        if err != nil {
22✔
1680
                return nil, err
1✔
1681
        }
1✔
1682

1683
        cols, err := rowReader.colsBySelector(ctx)
20✔
1684
        if err != nil {
20✔
1685
                return nil, err
×
1686
        }
×
1687

1688
        for {
70✔
1689
                row, err := rowReader.Read(ctx)
50✔
1690
                if errors.Is(err, ErrNoMoreRows) {
67✔
1691
                        break
17✔
1692
                } else if err != nil {
34✔
1693
                        return nil, err
1✔
1694
                }
1✔
1695

1696
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
32✔
1697

32✔
1698
                for _, col := range table.cols {
121✔
1699
                        encSel := EncodeSelector("", table.name, col.colName)
89✔
1700
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
89✔
1701
                }
89✔
1702

1703
                for _, update := range stmt.updates {
66✔
1704
                        col, err := table.GetColumnByName(update.col)
34✔
1705
                        if err != nil {
34✔
1706
                                return nil, err
×
1707
                        }
×
1708

1709
                        sval, err := update.val.substitute(params)
34✔
1710
                        if err != nil {
34✔
1711
                                return nil, err
×
1712
                        }
×
1713

1714
                        rval, err := sval.reduce(tx, row, table.name)
34✔
1715
                        if err != nil {
34✔
1716
                                return nil, err
×
1717
                        }
×
1718

1719
                        err = rval.requiresType(col.colType, cols, nil, table.name)
34✔
1720
                        if err != nil {
34✔
1721
                                return nil, err
×
1722
                        }
×
1723

1724
                        valuesByColID[col.id] = rval
34✔
1725
                }
1726

1727
                for i, col := range table.cols {
121✔
1728
                        v := valuesByColID[col.id]
89✔
1729

89✔
1730
                        row.ValuesByPosition[i] = v
89✔
1731
                        row.ValuesBySelector[EncodeSelector("", table.name, col.colName)] = v
89✔
1732
                }
89✔
1733

1734
                if err := checkConstraints(tx, table.checkConstraints, row, table.name); err != nil {
34✔
1735
                        return nil, err
2✔
1736
                }
2✔
1737

1738
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
30✔
1739
                if err != nil {
30✔
1740
                        return nil, err
×
1741
                }
×
1742

1743
                // primary index entry
1744
                mkey := MapKey(tx.sqlPrefix(), MappedPrefix, EncodeID(table.id), EncodeID(table.primaryIndex.id), pkEncVals, pkEncVals)
30✔
1745

30✔
1746
                // mkey must exist
30✔
1747
                _, err = tx.get(ctx, mkey)
30✔
1748
                if err != nil {
30✔
1749
                        return nil, err
×
1750
                }
×
1751

1752
                err = tx.doUpsert(ctx, pkEncVals, valuesByColID, table, true)
30✔
1753
                if err != nil {
30✔
1754
                        return nil, err
×
1755
                }
×
1756
        }
1757

1758
        return tx, nil
17✔
1759
}
1760

1761
type DeleteFromStmt struct {
1762
        tableRef *tableRef
1763
        where    ValueExp
1764
        indexOn  []string
1765
        orderBy  []*OrdExp
1766
        limit    ValueExp
1767
        offset   ValueExp
1768
}
1769

1770
func NewDeleteFromStmt(table string, where ValueExp, orderBy []*OrdExp, limit ValueExp) *DeleteFromStmt {
4✔
1771
        return &DeleteFromStmt{
4✔
1772
                tableRef: NewTableRef(table, ""),
4✔
1773
                where:    where,
4✔
1774
                orderBy:  orderBy,
4✔
1775
                limit:    limit,
4✔
1776
        }
4✔
1777
}
4✔
1778

1779
func (stmt *DeleteFromStmt) readOnly() bool {
1✔
1780
        return false
1✔
1781
}
1✔
1782

1783
func (stmt *DeleteFromStmt) requiredPrivileges() []SQLPrivilege {
1✔
1784
        return []SQLPrivilege{SQLPrivilegeDelete}
1✔
1785
}
1✔
1786

1787
func (stmt *DeleteFromStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
1788
        selectStmt := &SelectStmt{
1✔
1789
                ds:      stmt.tableRef,
1✔
1790
                where:   stmt.where,
1✔
1791
                orderBy: stmt.orderBy,
1✔
1792
        }
1✔
1793
        return selectStmt.inferParameters(ctx, tx, params)
1✔
1794
}
1✔
1795

1796
func (stmt *DeleteFromStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
15✔
1797
        selectStmt := &SelectStmt{
15✔
1798
                ds:      stmt.tableRef,
15✔
1799
                where:   stmt.where,
15✔
1800
                indexOn: stmt.indexOn,
15✔
1801
                orderBy: stmt.orderBy,
15✔
1802
                limit:   stmt.limit,
15✔
1803
                offset:  stmt.offset,
15✔
1804
        }
15✔
1805

15✔
1806
        rowReader, err := selectStmt.Resolve(ctx, tx, params, nil)
15✔
1807
        if err != nil {
17✔
1808
                return nil, err
2✔
1809
        }
2✔
1810
        defer rowReader.Close()
13✔
1811

13✔
1812
        table := rowReader.ScanSpecs().Index.table
13✔
1813

13✔
1814
        for {
147✔
1815
                row, err := rowReader.Read(ctx)
134✔
1816
                if errors.Is(err, ErrNoMoreRows) {
146✔
1817
                        break
12✔
1818
                }
1819
                if err != nil {
123✔
1820
                        return nil, err
1✔
1821
                }
1✔
1822

1823
                valuesByColID := make(map[uint32]TypedValue, len(row.ValuesBySelector))
121✔
1824

121✔
1825
                for _, col := range table.cols {
406✔
1826
                        encSel := EncodeSelector("", table.name, col.colName)
285✔
1827
                        valuesByColID[col.id] = row.ValuesBySelector[encSel]
285✔
1828
                }
285✔
1829

1830
                pkEncVals, err := encodedKey(table.primaryIndex, valuesByColID)
121✔
1831
                if err != nil {
121✔
1832
                        return nil, err
×
1833
                }
×
1834

1835
                err = tx.deleteIndexEntries(pkEncVals, valuesByColID, table)
121✔
1836
                if err != nil {
121✔
1837
                        return nil, err
×
1838
                }
×
1839

1840
                tx.updatedRows++
121✔
1841
        }
1842
        return tx, nil
12✔
1843
}
1844

1845
func (tx *SQLTx) deleteIndexEntries(pkEncVals []byte, valuesByColID map[uint32]TypedValue, table *Table) error {
121✔
1846
        encodedRowValue, err := tx.encodeRowValue(valuesByColID, table)
121✔
1847
        if err != nil {
121✔
1848
                return err
×
1849
        }
×
1850

1851
        for _, index := range table.indexes {
291✔
1852
                if !index.IsPrimary() {
219✔
1853
                        continue
49✔
1854
                }
1855

1856
                encodedValues := make([][]byte, 3+len(index.cols))
121✔
1857
                encodedValues[0] = EncodeID(DatabaseID)
121✔
1858
                encodedValues[1] = EncodeID(table.id)
121✔
1859
                encodedValues[2] = EncodeID(index.id)
121✔
1860

121✔
1861
                for i, col := range index.cols {
242✔
1862
                        val, specified := valuesByColID[col.id]
121✔
1863
                        if !specified {
121✔
1864
                                val = &NullValue{t: col.colType}
×
1865
                        }
×
1866

1867
                        encVal, _, _ := EncodeValueAsKey(val, col.colType, col.MaxLen())
121✔
1868

121✔
1869
                        encodedValues[i+3] = encVal
121✔
1870
                }
1871

1872
                md := store.NewKVMetadata()
121✔
1873

121✔
1874
                md.AsDeleted(true)
121✔
1875

121✔
1876
                err := tx.set(MapKey(tx.sqlPrefix(), RowPrefix, encodedValues...), md, encodedRowValue)
121✔
1877
                if err != nil {
121✔
1878
                        return err
×
1879
                }
×
1880
        }
1881

1882
        return nil
121✔
1883
}
1884

1885
type ValueExp interface {
1886
        inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
1887
        requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
1888
        substitute(params map[string]interface{}) (ValueExp, error)
1889
        selectors() []Selector
1890
        reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error)
1891
        reduceSelectors(row *Row, implicitTable string) ValueExp
1892
        isConstant() bool
1893
        selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error
1894
        String() string
1895
}
1896

1897
type typedValueRange struct {
1898
        lRange *typedValueSemiRange
1899
        hRange *typedValueSemiRange
1900
}
1901

1902
type typedValueSemiRange struct {
1903
        val       TypedValue
1904
        inclusive bool
1905
}
1906

1907
func (r *typedValueRange) unitary() bool {
19✔
1908
        // TODO: this simplified implementation doesn't cover all unitary cases e.g. 3<=v<4
19✔
1909
        if r.lRange == nil || r.hRange == nil {
19✔
1910
                return false
×
1911
        }
×
1912

1913
        res, _ := r.lRange.val.Compare(r.hRange.val)
19✔
1914
        return res == 0 && r.lRange.inclusive && r.hRange.inclusive
19✔
1915
}
1916

1917
func (r *typedValueRange) refineWith(refiningRange *typedValueRange) error {
3✔
1918
        if r.lRange == nil {
4✔
1919
                r.lRange = refiningRange.lRange
1✔
1920
        } else if r.lRange != nil && refiningRange.lRange != nil {
4✔
1921
                maxRange, err := maxSemiRange(r.lRange, refiningRange.lRange)
1✔
1922
                if err != nil {
1✔
1923
                        return err
×
1924
                }
×
1925
                r.lRange = maxRange
1✔
1926
        }
1927

1928
        if r.hRange == nil {
4✔
1929
                r.hRange = refiningRange.hRange
1✔
1930
        } else if r.hRange != nil && refiningRange.hRange != nil {
5✔
1931
                minRange, err := minSemiRange(r.hRange, refiningRange.hRange)
2✔
1932
                if err != nil {
2✔
1933
                        return err
×
1934
                }
×
1935
                r.hRange = minRange
2✔
1936
        }
1937

1938
        return nil
3✔
1939
}
1940

1941
func (r *typedValueRange) extendWith(extendingRange *typedValueRange) error {
5✔
1942
        if r.lRange == nil || extendingRange.lRange == nil {
7✔
1943
                r.lRange = nil
2✔
1944
        } else {
5✔
1945
                minRange, err := minSemiRange(r.lRange, extendingRange.lRange)
3✔
1946
                if err != nil {
3✔
1947
                        return err
×
1948
                }
×
1949
                r.lRange = minRange
3✔
1950
        }
1951

1952
        if r.hRange == nil || extendingRange.hRange == nil {
8✔
1953
                r.hRange = nil
3✔
1954
        } else {
5✔
1955
                maxRange, err := maxSemiRange(r.hRange, extendingRange.hRange)
2✔
1956
                if err != nil {
2✔
1957
                        return err
×
1958
                }
×
1959
                r.hRange = maxRange
2✔
1960
        }
1961

1962
        return nil
5✔
1963
}
1964

1965
func maxSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
3✔
1966
        r, err := or1.val.Compare(or2.val)
3✔
1967
        if err != nil {
3✔
1968
                return nil, err
×
1969
        }
×
1970

1971
        maxVal := or1.val
3✔
1972
        if r < 0 {
5✔
1973
                maxVal = or2.val
2✔
1974
        }
2✔
1975

1976
        return &typedValueSemiRange{
3✔
1977
                val:       maxVal,
3✔
1978
                inclusive: or1.inclusive && or2.inclusive,
3✔
1979
        }, nil
3✔
1980
}
1981

1982
func minSemiRange(or1, or2 *typedValueSemiRange) (*typedValueSemiRange, error) {
5✔
1983
        r, err := or1.val.Compare(or2.val)
5✔
1984
        if err != nil {
5✔
1985
                return nil, err
×
1986
        }
×
1987

1988
        minVal := or1.val
5✔
1989
        if r > 0 {
9✔
1990
                minVal = or2.val
4✔
1991
        }
4✔
1992

1993
        return &typedValueSemiRange{
5✔
1994
                val:       minVal,
5✔
1995
                inclusive: or1.inclusive || or2.inclusive,
5✔
1996
        }, nil
5✔
1997
}
1998

1999
type TypedValue interface {
2000
        ValueExp
2001
        Type() SQLValueType
2002
        RawValue() interface{}
2003
        Compare(val TypedValue) (int, error)
2004
        IsNull() bool
2005
}
2006

2007
type Tuple []TypedValue
2008

2009
func (t Tuple) Compare(other Tuple) (int, int, error) {
204,158✔
2010
        if len(t) != len(other) {
204,158✔
2011
                return -1, -1, ErrNotComparableValues
×
2012
        }
×
2013

2014
        for i := range t {
431,257✔
2015
                res, err := t[i].Compare(other[i])
227,099✔
2016
                if err != nil || res != 0 {
420,947✔
2017
                        return res, i, err
193,848✔
2018
                }
193,848✔
2019
        }
2020
        return 0, -1, nil
10,310✔
2021
}
2022

2023
func NewNull(t SQLValueType) *NullValue {
394✔
2024
        return &NullValue{t: t}
394✔
2025
}
394✔
2026

2027
type NullValue struct {
2028
        t SQLValueType
2029
}
2030

2031
func (n *NullValue) Type() SQLValueType {
108✔
2032
        return n.t
108✔
2033
}
108✔
2034

2035
func (n *NullValue) RawValue() interface{} {
360✔
2036
        return nil
360✔
2037
}
360✔
2038

2039
func (n *NullValue) IsNull() bool {
378✔
2040
        return true
378✔
2041
}
378✔
2042

2043
func (n *NullValue) String() string {
4✔
2044
        return "NULL"
4✔
2045
}
4✔
2046

2047
func (n *NullValue) Compare(val TypedValue) (int, error) {
81✔
2048
        if n.t != AnyType && val.Type() != AnyType && n.t != val.Type() {
82✔
2049
                return 0, ErrNotComparableValues
1✔
2050
        }
1✔
2051

2052
        if val.RawValue() == nil {
121✔
2053
                return 0, nil
41✔
2054
        }
41✔
2055
        return -1, nil
39✔
2056
}
2057

2058
func (v *NullValue) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
7✔
2059
        return v.t, nil
7✔
2060
}
7✔
2061

2062
func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
11✔
2063
        if v.t == t {
16✔
2064
                return nil
5✔
2065
        }
5✔
2066

2067
        if v.t != AnyType {
7✔
2068
                return ErrInvalidTypes
1✔
2069
        }
1✔
2070

2071
        v.t = t
5✔
2072

5✔
2073
        return nil
5✔
2074
}
2075

2076
func (v *NullValue) selectors() []Selector {
13✔
2077
        return nil
13✔
2078
}
13✔
2079

2080
func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) {
394✔
2081
        return v, nil
394✔
2082
}
394✔
2083

2084
func (v *NullValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
349✔
2085
        return v, nil
349✔
2086
}
349✔
2087

2088
func (v *NullValue) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
2089
        return v
10✔
2090
}
10✔
2091

2092
func (v *NullValue) isConstant() bool {
12✔
2093
        return true
12✔
2094
}
12✔
2095

2096
func (v *NullValue) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2097
        return nil
1✔
2098
}
1✔
2099

2100
type Integer struct {
2101
        val int64
2102
}
2103

2104
func NewInteger(val int64) *Integer {
297✔
2105
        return &Integer{val: val}
297✔
2106
}
297✔
2107

2108
func (v *Integer) Type() SQLValueType {
302,780✔
2109
        return IntegerType
302,780✔
2110
}
302,780✔
2111

2112
func (v *Integer) IsNull() bool {
116,514✔
2113
        return false
116,514✔
2114
}
116,514✔
2115

2116
func (v *Integer) String() string {
54✔
2117
        return strconv.FormatInt(v.val, 10)
54✔
2118
}
54✔
2119

2120
func (v *Integer) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
79✔
2121
        return IntegerType, nil
79✔
2122
}
79✔
2123

2124
func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
44✔
2125
        if t != IntegerType && t != JSONType {
48✔
2126
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
4✔
2127
        }
4✔
2128

2129
        return nil
40✔
2130
}
2131

2132
func (v *Integer) selectors() []Selector {
50✔
2133
        return nil
50✔
2134
}
50✔
2135

2136
func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) {
15,104✔
2137
        return v, nil
15,104✔
2138
}
15,104✔
2139

2140
func (v *Integer) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
16,872✔
2141
        return v, nil
16,872✔
2142
}
16,872✔
2143

2144
func (v *Integer) reduceSelectors(row *Row, implicitTable string) ValueExp {
9✔
2145
        return v
9✔
2146
}
9✔
2147

2148
func (v *Integer) isConstant() bool {
116✔
2149
        return true
116✔
2150
}
116✔
2151

2152
func (v *Integer) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2153
        return nil
1✔
2154
}
1✔
2155

2156
func (v *Integer) RawValue() interface{} {
177,127✔
2157
        return v.val
177,127✔
2158
}
177,127✔
2159

2160
func (v *Integer) Compare(val TypedValue) (int, error) {
93,031✔
2161
        if val.IsNull() {
93,074✔
2162
                return 1, nil
43✔
2163
        }
43✔
2164

2165
        if val.Type() == JSONType {
92,989✔
2166
                res, err := val.Compare(v)
1✔
2167
                return -res, err
1✔
2168
        }
1✔
2169

2170
        if val.Type() == Float64Type {
92,987✔
2171
                r, err := val.Compare(v)
×
2172
                return r * -1, err
×
2173
        }
×
2174

2175
        if val.Type() != IntegerType {
92,994✔
2176
                return 0, ErrNotComparableValues
7✔
2177
        }
7✔
2178

2179
        rval := val.RawValue().(int64)
92,980✔
2180

92,980✔
2181
        if v.val == rval {
111,624✔
2182
                return 0, nil
18,644✔
2183
        }
18,644✔
2184

2185
        if v.val > rval {
109,210✔
2186
                return 1, nil
34,874✔
2187
        }
34,874✔
2188

2189
        return -1, nil
39,462✔
2190
}
2191

2192
type Timestamp struct {
2193
        val time.Time
2194
}
2195

2196
func (v *Timestamp) Type() SQLValueType {
40,287✔
2197
        return TimestampType
40,287✔
2198
}
40,287✔
2199

2200
func (v *Timestamp) IsNull() bool {
32,870✔
2201
        return false
32,870✔
2202
}
32,870✔
2203

2204
func (v *Timestamp) String() string {
1✔
2205
        return v.val.Format("2006-01-02 15:04:05.999999")
1✔
2206
}
1✔
2207

2208
func (v *Timestamp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2209
        return TimestampType, nil
1✔
2210
}
1✔
2211

2212
func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
14✔
2213
        if t != TimestampType {
15✔
2214
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
1✔
2215
        }
1✔
2216

2217
        return nil
13✔
2218
}
2219

2220
func (v *Timestamp) selectors() []Selector {
1✔
2221
        return nil
1✔
2222
}
1✔
2223

2224
func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) {
1,163✔
2225
        return v, nil
1,163✔
2226
}
1,163✔
2227

2228
func (v *Timestamp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2,023✔
2229
        return v, nil
2,023✔
2230
}
2,023✔
2231

2232
func (v *Timestamp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2233
        return v
1✔
2234
}
1✔
2235

2236
func (v *Timestamp) isConstant() bool {
1✔
2237
        return true
1✔
2238
}
1✔
2239

2240
func (v *Timestamp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2241
        return nil
1✔
2242
}
1✔
2243

2244
func (v *Timestamp) RawValue() interface{} {
57,384✔
2245
        return v.val
57,384✔
2246
}
57,384✔
2247

2248
func (v *Timestamp) Compare(val TypedValue) (int, error) {
29,744✔
2249
        if val.IsNull() {
29,746✔
2250
                return 1, nil
2✔
2251
        }
2✔
2252

2253
        if val.Type() != TimestampType {
29,743✔
2254
                return 0, ErrNotComparableValues
1✔
2255
        }
1✔
2256

2257
        rval := val.RawValue().(time.Time)
29,741✔
2258

29,741✔
2259
        if v.val.Before(rval) {
44,251✔
2260
                return -1, nil
14,510✔
2261
        }
14,510✔
2262

2263
        if v.val.After(rval) {
30,271✔
2264
                return 1, nil
15,040✔
2265
        }
15,040✔
2266

2267
        return 0, nil
191✔
2268
}
2269

2270
type Varchar struct {
2271
        val string
2272
}
2273

2274
func NewVarchar(val string) *Varchar {
2,051✔
2275
        return &Varchar{val: val}
2,051✔
2276
}
2,051✔
2277

2278
func (v *Varchar) Type() SQLValueType {
127,932✔
2279
        return VarcharType
127,932✔
2280
}
127,932✔
2281

2282
func (v *Varchar) IsNull() bool {
64,567✔
2283
        return false
64,567✔
2284
}
64,567✔
2285

2286
func (v *Varchar) String() string {
18✔
2287
        return fmt.Sprintf("'%s'", v.val)
18✔
2288
}
18✔
2289

2290
func (v *Varchar) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
59✔
2291
        return VarcharType, nil
59✔
2292
}
59✔
2293

2294
func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
140✔
2295
        if t != VarcharType && t != JSONType {
142✔
2296
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
2✔
2297
        }
2✔
2298
        return nil
138✔
2299
}
2300

2301
func (v *Varchar) selectors() []Selector {
32✔
2302
        return nil
32✔
2303
}
32✔
2304

2305
func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) {
4,938✔
2306
        return v, nil
4,938✔
2307
}
4,938✔
2308

2309
func (v *Varchar) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5,627✔
2310
        return v, nil
5,627✔
2311
}
5,627✔
2312

2313
func (v *Varchar) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2314
        return v
×
2315
}
×
2316

2317
func (v *Varchar) isConstant() bool {
39✔
2318
        return true
39✔
2319
}
39✔
2320

2321
func (v *Varchar) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2322
        return nil
1✔
2323
}
1✔
2324

2325
func (v *Varchar) RawValue() interface{} {
90,142✔
2326
        return v.val
90,142✔
2327
}
90,142✔
2328

2329
func (v *Varchar) Compare(val TypedValue) (int, error) {
58,122✔
2330
        if val.IsNull() {
58,178✔
2331
                return 1, nil
56✔
2332
        }
56✔
2333

2334
        if val.Type() == JSONType {
59,067✔
2335
                res, err := val.Compare(v)
1,001✔
2336
                return -res, err
1,001✔
2337
        }
1,001✔
2338

2339
        if val.Type() != VarcharType {
57,066✔
2340
                return 0, ErrNotComparableValues
1✔
2341
        }
1✔
2342

2343
        rval := val.RawValue().(string)
57,064✔
2344

57,064✔
2345
        return bytes.Compare([]byte(v.val), []byte(rval)), nil
57,064✔
2346
}
2347

2348
type UUID struct {
2349
        val uuid.UUID
2350
}
2351

2352
func NewUUID(val uuid.UUID) *UUID {
1✔
2353
        return &UUID{val: val}
1✔
2354
}
1✔
2355

2356
func (v *UUID) Type() SQLValueType {
10✔
2357
        return UUIDType
10✔
2358
}
10✔
2359

2360
func (v *UUID) IsNull() bool {
26✔
2361
        return false
26✔
2362
}
26✔
2363

2364
func (v *UUID) String() string {
1✔
2365
        return v.val.String()
1✔
2366
}
1✔
2367

2368
func (v *UUID) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2369
        return UUIDType, nil
1✔
2370
}
1✔
2371

2372
func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
4✔
2373
        if t != UUIDType {
6✔
2374
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
2✔
2375
        }
2✔
2376

2377
        return nil
2✔
2378
}
2379

2380
func (v *UUID) selectors() []Selector {
1✔
2381
        return nil
1✔
2382
}
1✔
2383

2384
func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) {
6✔
2385
        return v, nil
6✔
2386
}
6✔
2387

2388
func (v *UUID) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
5✔
2389
        return v, nil
5✔
2390
}
5✔
2391

2392
func (v *UUID) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2393
        return v
1✔
2394
}
1✔
2395

2396
func (v *UUID) isConstant() bool {
1✔
2397
        return true
1✔
2398
}
1✔
2399

2400
func (v *UUID) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2401
        return nil
1✔
2402
}
1✔
2403

2404
func (v *UUID) RawValue() interface{} {
41✔
2405
        return v.val
41✔
2406
}
41✔
2407

2408
func (v *UUID) Compare(val TypedValue) (int, error) {
5✔
2409
        if val.IsNull() {
7✔
2410
                return 1, nil
2✔
2411
        }
2✔
2412

2413
        if val.Type() != UUIDType {
4✔
2414
                return 0, ErrNotComparableValues
1✔
2415
        }
1✔
2416

2417
        rval := val.RawValue().(uuid.UUID)
2✔
2418

2✔
2419
        return bytes.Compare(v.val[:], rval[:]), nil
2✔
2420
}
2421

2422
type Bool struct {
2423
        val bool
2424
}
2425

2426
func NewBool(val bool) *Bool {
208✔
2427
        return &Bool{val: val}
208✔
2428
}
208✔
2429

2430
func (v *Bool) Type() SQLValueType {
1,969✔
2431
        return BooleanType
1,969✔
2432
}
1,969✔
2433

2434
func (v *Bool) IsNull() bool {
1,479✔
2435
        return false
1,479✔
2436
}
1,479✔
2437

2438
func (v *Bool) String() string {
41✔
2439
        return strconv.FormatBool(v.val)
41✔
2440
}
41✔
2441

2442
func (v *Bool) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
31✔
2443
        return BooleanType, nil
31✔
2444
}
31✔
2445

2446
func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
55✔
2447
        if t != BooleanType && t != JSONType {
60✔
2448
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
5✔
2449
        }
5✔
2450
        return nil
50✔
2451
}
2452

2453
func (v *Bool) selectors() []Selector {
4✔
2454
        return nil
4✔
2455
}
4✔
2456

2457
func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) {
636✔
2458
        return v, nil
636✔
2459
}
636✔
2460

2461
func (v *Bool) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
724✔
2462
        return v, nil
724✔
2463
}
724✔
2464

2465
func (v *Bool) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2466
        return v
×
2467
}
×
2468

2469
func (v *Bool) isConstant() bool {
3✔
2470
        return true
3✔
2471
}
3✔
2472

2473
func (v *Bool) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
2474
        return nil
7✔
2475
}
7✔
2476

2477
func (v *Bool) RawValue() interface{} {
1,715✔
2478
        return v.val
1,715✔
2479
}
1,715✔
2480

2481
func (v *Bool) Compare(val TypedValue) (int, error) {
575✔
2482
        if val.IsNull() {
605✔
2483
                return 1, nil
30✔
2484
        }
30✔
2485

2486
        if val.Type() == JSONType {
546✔
2487
                res, err := val.Compare(v)
1✔
2488
                return -res, err
1✔
2489
        }
1✔
2490

2491
        if val.Type() != BooleanType {
544✔
2492
                return 0, ErrNotComparableValues
×
2493
        }
×
2494

2495
        rval := val.RawValue().(bool)
544✔
2496

544✔
2497
        if v.val == rval {
888✔
2498
                return 0, nil
344✔
2499
        }
344✔
2500

2501
        if v.val {
206✔
2502
                return 1, nil
6✔
2503
        }
6✔
2504

2505
        return -1, nil
194✔
2506
}
2507

2508
type Blob struct {
2509
        val []byte
2510
}
2511

2512
func NewBlob(val []byte) *Blob {
286✔
2513
        return &Blob{val: val}
286✔
2514
}
286✔
2515

2516
func (v *Blob) Type() SQLValueType {
53✔
2517
        return BLOBType
53✔
2518
}
53✔
2519

2520
func (v *Blob) IsNull() bool {
2,312✔
2521
        return false
2,312✔
2522
}
2,312✔
2523

2524
func (v *Blob) String() string {
2✔
2525
        return hex.EncodeToString(v.val)
2✔
2526
}
2✔
2527

2528
func (v *Blob) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
2529
        return BLOBType, nil
1✔
2530
}
1✔
2531

2532
func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
2533
        if t != BLOBType {
3✔
2534
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BLOBType, t)
1✔
2535
        }
1✔
2536

2537
        return nil
1✔
2538
}
2539

2540
func (v *Blob) selectors() []Selector {
1✔
2541
        return nil
1✔
2542
}
1✔
2543

2544
func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) {
714✔
2545
        return v, nil
714✔
2546
}
714✔
2547

2548
func (v *Blob) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
726✔
2549
        return v, nil
726✔
2550
}
726✔
2551

2552
func (v *Blob) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2553
        return v
×
2554
}
×
2555

2556
func (v *Blob) isConstant() bool {
7✔
2557
        return true
7✔
2558
}
7✔
2559

2560
func (v *Blob) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2561
        return nil
×
2562
}
×
2563

2564
func (v *Blob) RawValue() interface{} {
2,570✔
2565
        return v.val
2,570✔
2566
}
2,570✔
2567

2568
func (v *Blob) Compare(val TypedValue) (int, error) {
25✔
2569
        if val.IsNull() {
25✔
2570
                return 1, nil
×
2571
        }
×
2572

2573
        if val.Type() != BLOBType {
25✔
2574
                return 0, ErrNotComparableValues
×
2575
        }
×
2576

2577
        rval := val.RawValue().([]byte)
25✔
2578

25✔
2579
        return bytes.Compare(v.val, rval), nil
25✔
2580
}
2581

2582
type Float64 struct {
2583
        val float64
2584
}
2585

2586
func NewFloat64(val float64) *Float64 {
1,249✔
2587
        return &Float64{val: val}
1,249✔
2588
}
1,249✔
2589

2590
func (v *Float64) Type() SQLValueType {
207,440✔
2591
        return Float64Type
207,440✔
2592
}
207,440✔
2593

2594
func (v *Float64) IsNull() bool {
5,598✔
2595
        return false
5,598✔
2596
}
5,598✔
2597

2598
func (v *Float64) String() string {
3✔
2599
        return strconv.FormatFloat(float64(v.val), 'f', -1, 64)
3✔
2600
}
3✔
2601

2602
func (v *Float64) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
15✔
2603
        return Float64Type, nil
15✔
2604
}
15✔
2605

2606
func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
20✔
2607
        if t != Float64Type && t != JSONType {
21✔
2608
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, Float64Type, t)
1✔
2609
        }
1✔
2610
        return nil
19✔
2611
}
2612

2613
func (v *Float64) selectors() []Selector {
2✔
2614
        return nil
2✔
2615
}
2✔
2616

2617
func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) {
1,817✔
2618
        return v, nil
1,817✔
2619
}
1,817✔
2620

2621
func (v *Float64) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
3,428✔
2622
        return v, nil
3,428✔
2623
}
3,428✔
2624

2625
func (v *Float64) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
2626
        return v
1✔
2627
}
1✔
2628

2629
func (v *Float64) isConstant() bool {
5✔
2630
        return true
5✔
2631
}
5✔
2632

2633
func (v *Float64) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
2634
        return nil
1✔
2635
}
1✔
2636

2637
func (v *Float64) RawValue() interface{} {
365,970✔
2638
        return v.val
365,970✔
2639
}
365,970✔
2640

2641
func (v *Float64) Compare(val TypedValue) (int, error) {
61,875✔
2642
        if val.Type() == JSONType {
61,876✔
2643
                res, err := val.Compare(v)
1✔
2644
                return -res, err
1✔
2645
        }
1✔
2646

2647
        convVal, err := mayApplyImplicitConversion(val.RawValue(), Float64Type)
61,874✔
2648
        if err != nil {
61,875✔
2649
                return 0, err
1✔
2650
        }
1✔
2651

2652
        if convVal == nil {
61,876✔
2653
                return 1, nil
3✔
2654
        }
3✔
2655

2656
        rval, ok := convVal.(float64)
61,870✔
2657
        if !ok {
61,870✔
2658
                return 0, ErrNotComparableValues
×
2659
        }
×
2660

2661
        if v.val == rval {
61,995✔
2662
                return 0, nil
125✔
2663
        }
125✔
2664

2665
        if v.val > rval {
90,617✔
2666
                return 1, nil
28,872✔
2667
        }
28,872✔
2668

2669
        return -1, nil
32,873✔
2670
}
2671

2672
type FnCall struct {
2673
        fn     string
2674
        params []ValueExp
2675
}
2676

2677
func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2678
        fn, err := v.resolveFunc()
23✔
2679
        if err != nil {
24✔
2680
                return AnyType, nil
1✔
2681
        }
1✔
2682
        return fn.InferType(cols, params, implicitTable)
22✔
2683
}
2684

2685
func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
18✔
2686
        fn, err := v.resolveFunc()
18✔
2687
        if err != nil {
19✔
2688
                return err
1✔
2689
        }
1✔
2690
        return fn.RequiresType(t, cols, params, implicitTable)
17✔
2691
}
2692

2693
func (v *FnCall) selectors() []Selector {
33✔
2694
        selectors := make([]Selector, 0)
33✔
2695
        for _, param := range v.params {
89✔
2696
                selectors = append(selectors, param.selectors()...)
56✔
2697
        }
56✔
2698
        return selectors
33✔
2699
}
2700

2701
func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) {
433✔
2702
        ps := make([]ValueExp, len(v.params))
433✔
2703
        for i, p := range v.params {
791✔
2704
                ps[i], err = p.substitute(params)
358✔
2705
                if err != nil {
358✔
2706
                        return nil, err
×
2707
                }
×
2708
        }
2709

2710
        return &FnCall{
433✔
2711
                fn:     v.fn,
433✔
2712
                params: ps,
433✔
2713
        }, nil
433✔
2714
}
2715

2716
func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
433✔
2717
        fn, err := v.resolveFunc()
433✔
2718
        if err != nil {
434✔
2719
                return nil, err
1✔
2720
        }
1✔
2721

2722
        fnInputs, err := v.reduceParams(tx, row, implicitTable)
432✔
2723
        if err != nil {
432✔
2724
                return nil, err
×
2725
        }
×
2726
        return fn.Apply(tx, fnInputs)
432✔
2727
}
2728

2729
func (v *FnCall) reduceParams(tx *SQLTx, row *Row, implicitTable string) ([]TypedValue, error) {
432✔
2730
        var values []TypedValue
432✔
2731
        if len(v.params) > 0 {
763✔
2732
                values = make([]TypedValue, len(v.params))
331✔
2733
                for i, p := range v.params {
689✔
2734
                        v, err := p.reduce(tx, row, implicitTable)
358✔
2735
                        if err != nil {
358✔
2736
                                return nil, err
×
2737
                        }
×
2738
                        values[i] = v
358✔
2739
                }
2740
        }
2741
        return values, nil
432✔
2742
}
2743

2744
func (v *FnCall) resolveFunc() (Function, error) {
474✔
2745
        fn, exists := builtinFunctions[strings.ToUpper(v.fn)]
474✔
2746
        if !exists {
477✔
2747
                return nil, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn)
3✔
2748
        }
3✔
2749
        return fn, nil
471✔
2750
}
2751

2752
func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2753
        return v
×
2754
}
×
2755

2756
func (v *FnCall) isConstant() bool {
13✔
2757
        return false
13✔
2758
}
13✔
2759

2760
func (v *FnCall) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2761
        return nil
×
2762
}
×
2763

2764
func (v *FnCall) String() string {
1✔
2765
        params := make([]string, len(v.params))
1✔
2766
        for i, p := range v.params {
4✔
2767
                params[i] = p.String()
3✔
2768
        }
3✔
2769
        return v.fn + "(" + strings.Join(params, ",") + ")"
1✔
2770
}
2771

2772
type Cast struct {
2773
        val ValueExp
2774
        t   SQLValueType
2775
}
2776

2777
func (c *Cast) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
23✔
2778
        _, err := c.val.inferType(cols, params, implicitTable)
23✔
2779
        if err != nil {
24✔
2780
                return AnyType, err
1✔
2781
        }
1✔
2782

2783
        // val type may be restricted by compatible conversions, but multiple types may be compatible...
2784

2785
        return c.t, nil
22✔
2786
}
2787

2788
func (c *Cast) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
×
2789
        if c.t != t {
×
2790
                return fmt.Errorf("%w: can not use value cast to %s as %s", ErrInvalidTypes, c.t, t)
×
2791
        }
×
2792

2793
        return nil
×
2794
}
2795

2796
func (c *Cast) substitute(params map[string]interface{}) (ValueExp, error) {
277✔
2797
        val, err := c.val.substitute(params)
277✔
2798
        if err != nil {
277✔
2799
                return nil, err
×
2800
        }
×
2801
        c.val = val
277✔
2802
        return c, nil
277✔
2803
}
2804

2805
func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
265✔
2806
        val, err := c.val.reduce(tx, row, implicitTable)
265✔
2807
        if err != nil {
265✔
2808
                return nil, err
×
2809
        }
×
2810

2811
        conv, err := getConverter(val.Type(), c.t)
265✔
2812
        if conv == nil {
268✔
2813
                return nil, err
3✔
2814
        }
3✔
2815

2816
        return conv(val)
262✔
2817
}
2818

2819
func (v *Cast) selectors() []Selector {
4✔
2820
        return v.val.selectors()
4✔
2821
}
4✔
2822

2823
func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2824
        return &Cast{
×
2825
                val: c.val.reduceSelectors(row, implicitTable),
×
2826
                t:   c.t,
×
2827
        }
×
2828
}
×
2829

2830
func (c *Cast) isConstant() bool {
7✔
2831
        return c.val.isConstant()
7✔
2832
}
7✔
2833

2834
func (c *Cast) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
2835
        return nil
×
2836
}
×
2837

2838
func (c *Cast) String() string {
1✔
2839
        return fmt.Sprintf("CAST (%s AS %s)", c.val.String(), c.t)
1✔
2840
}
1✔
2841

2842
type Param struct {
2843
        id  string
2844
        pos int
2845
}
2846

2847
func (v *Param) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
59✔
2848
        t, ok := params[v.id]
59✔
2849
        if !ok {
116✔
2850
                params[v.id] = AnyType
57✔
2851
                return AnyType, nil
57✔
2852
        }
57✔
2853

2854
        return t, nil
2✔
2855
}
2856

2857
func (v *Param) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
76✔
2858
        currT, ok := params[v.id]
76✔
2859
        if ok && currT != t && currT != AnyType {
80✔
2860
                return ErrInferredMultipleTypes
4✔
2861
        }
4✔
2862

2863
        params[v.id] = t
72✔
2864

72✔
2865
        return nil
72✔
2866
}
2867

2868
func (p *Param) substitute(params map[string]interface{}) (ValueExp, error) {
6,399✔
2869
        val, ok := params[p.id]
6,399✔
2870
        if !ok {
6,461✔
2871
                return nil, fmt.Errorf("%w(%s)", ErrMissingParameter, p.id)
62✔
2872
        }
62✔
2873

2874
        if val == nil {
6,386✔
2875
                return &NullValue{t: AnyType}, nil
49✔
2876
        }
49✔
2877

2878
        switch v := val.(type) {
6,288✔
2879
        case bool:
96✔
2880
                {
192✔
2881
                        return &Bool{val: v}, nil
96✔
2882
                }
96✔
2883
        case string:
1,752✔
2884
                {
3,504✔
2885
                        return &Varchar{val: v}, nil
1,752✔
2886
                }
1,752✔
2887
        case int:
1,678✔
2888
                {
3,356✔
2889
                        return &Integer{val: int64(v)}, nil
1,678✔
2890
                }
1,678✔
2891
        case uint:
×
2892
                {
×
2893
                        return &Integer{val: int64(v)}, nil
×
2894
                }
×
2895
        case uint64:
34✔
2896
                {
68✔
2897
                        return &Integer{val: int64(v)}, nil
34✔
2898
                }
34✔
2899
        case int64:
227✔
2900
                {
454✔
2901
                        return &Integer{val: v}, nil
227✔
2902
                }
227✔
2903
        case []byte:
14✔
2904
                {
28✔
2905
                        return &Blob{val: v}, nil
14✔
2906
                }
14✔
2907
        case time.Time:
861✔
2908
                {
1,722✔
2909
                        return &Timestamp{val: v.Truncate(time.Microsecond).UTC()}, nil
861✔
2910
                }
861✔
2911
        case float64:
1,625✔
2912
                {
3,250✔
2913
                        return &Float64{val: v}, nil
1,625✔
2914
                }
1,625✔
2915
        }
2916
        return nil, ErrUnsupportedParameter
1✔
2917
}
2918

2919
func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
×
2920
        return nil, ErrUnexpected
×
2921
}
×
2922

2923
func (p *Param) selectors() []Selector {
4✔
2924
        return nil
4✔
2925
}
4✔
2926

2927
func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
2928
        return p
×
2929
}
×
2930

2931
func (p *Param) isConstant() bool {
130✔
2932
        return true
130✔
2933
}
130✔
2934

2935
func (v *Param) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
5✔
2936
        return nil
5✔
2937
}
5✔
2938

2939
func (v *Param) String() string {
2✔
2940
        return "@" + v.id
2✔
2941
}
2✔
2942

2943
type whenThenClause struct {
2944
        when, then ValueExp
2945
}
2946

2947
type CaseWhenExp struct {
2948
        exp      ValueExp
2949
        whenThen []whenThenClause
2950
        elseExp  ValueExp
2951
}
2952

2953
func (ce *CaseWhenExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
8✔
2954
        checkType := func(e ValueExp, expectedType SQLValueType) (string, error) {
18✔
2955
                t, err := e.inferType(cols, params, implicitTable)
10✔
2956
                if err != nil {
10✔
2957
                        return "", err
×
2958
                }
×
2959

2960
                if expectedType == AnyType {
15✔
2961
                        return t, nil
5✔
2962
                }
5✔
2963

2964
                if t != expectedType {
6✔
2965
                        if (t == Float64Type && expectedType == IntegerType) ||
1✔
2966
                                (t == IntegerType && expectedType == Float64Type) {
1✔
2967
                                return Float64Type, nil
×
2968
                        }
×
2969
                        return "", fmt.Errorf("%w: CASE types %s and %s cannot be matched", ErrInferredMultipleTypes, expectedType, t)
1✔
2970
                }
2971
                return t, nil
4✔
2972
        }
2973

2974
        searchType := BooleanType
8✔
2975
        inferredResType := AnyType
8✔
2976
        if ce.exp != nil {
11✔
2977
                t, err := ce.exp.inferType(cols, params, implicitTable)
3✔
2978
                if err != nil {
3✔
2979
                        return "", err
×
2980
                }
×
2981
                searchType = t
3✔
2982
        }
2983

2984
        for _, e := range ce.whenThen {
16✔
2985
                whenType, err := e.when.inferType(cols, params, implicitTable)
8✔
2986
                if err != nil {
8✔
2987
                        return "", err
×
2988
                }
×
2989

2990
                if whenType != searchType {
11✔
2991
                        return "", fmt.Errorf("%w: argument of CASE/WHEN must be of type %s, not type %s", ErrInvalidTypes, searchType, whenType)
3✔
2992
                }
3✔
2993

2994
                t, err := checkType(e.then, inferredResType)
5✔
2995
                if err != nil {
5✔
2996
                        return "", err
×
2997
                }
×
2998
                inferredResType = t
5✔
2999
        }
3000

3001
        if ce.elseExp != nil {
10✔
3002
                return checkType(ce.elseExp, inferredResType)
5✔
3003
        }
5✔
NEW
3004
        return inferredResType, nil
×
3005
}
3006

3007
func (ce *CaseWhenExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
3008
        inferredType, err := ce.inferType(cols, params, implicitTable)
2✔
3009
        if err != nil {
3✔
3010
                return err
1✔
3011
        }
1✔
3012

3013
        if inferredType != t {
1✔
3014
                return fmt.Errorf("%w: expected type %s but %s found instead", ErrInvalidTypes, t, inferredType)
×
3015
        }
×
3016
        return nil
1✔
3017
}
3018

3019
func (ce *CaseWhenExp) substitute(params map[string]interface{}) (ValueExp, error) {
504✔
3020
        var exp ValueExp
504✔
3021
        if ce.exp != nil {
605✔
3022
                e, err := ce.exp.substitute(params)
101✔
3023
                if err != nil {
101✔
3024
                        return nil, err
×
3025
                }
×
3026
                exp = e
101✔
3027
        }
3028

3029
        whenThen := make([]whenThenClause, len(ce.whenThen))
504✔
3030
        for i, wt := range ce.whenThen {
1,208✔
3031
                whenValue, err := wt.when.substitute(params)
704✔
3032
                if err != nil {
704✔
3033
                        return nil, err
×
3034
                }
×
3035
                whenThen[i].when = whenValue
704✔
3036

704✔
3037
                thenValue, err := wt.then.substitute(params)
704✔
3038
                if err != nil {
704✔
3039
                        return nil, err
×
3040
                }
×
3041
                whenThen[i].then = thenValue
704✔
3042
        }
3043

3044
        if ce.elseExp == nil {
506✔
3045
                return &CaseWhenExp{
2✔
3046
                        exp:      exp,
2✔
3047
                        whenThen: whenThen,
2✔
3048
                }, nil
2✔
3049
        }
2✔
3050

3051
        elseValue, err := ce.elseExp.substitute(params)
502✔
3052
        if err != nil {
502✔
3053
                return nil, err
×
3054
        }
×
3055
        return &CaseWhenExp{
502✔
3056
                exp:      exp,
502✔
3057
                whenThen: whenThen,
502✔
3058
                elseExp:  elseValue,
502✔
3059
        }, nil
502✔
3060
}
3061

3062
func (ce *CaseWhenExp) selectors() []Selector {
7✔
3063
        selectors := make([]Selector, 0)
7✔
3064
        if ce.exp != nil {
8✔
3065
                selectors = append(selectors, ce.exp.selectors()...)
1✔
3066
        }
1✔
3067

3068
        for _, wh := range ce.whenThen {
16✔
3069
                selectors = append(selectors, wh.when.selectors()...)
9✔
3070
                selectors = append(selectors, wh.then.selectors()...)
9✔
3071
        }
9✔
3072

3073
        if ce.elseExp == nil {
9✔
3074
                return selectors
2✔
3075
        }
2✔
3076
        return append(selectors, ce.elseExp.selectors()...)
5✔
3077
}
3078

3079
func (ce *CaseWhenExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
304✔
3080
        var searchValue TypedValue
304✔
3081
        if ce.exp != nil {
405✔
3082
                v, err := ce.exp.reduce(tx, row, implicitTable)
101✔
3083
                if err != nil {
101✔
3084
                        return nil, err
×
3085
                }
×
3086
                searchValue = v
101✔
3087
        } else {
203✔
3088
                searchValue = &Bool{val: true}
203✔
3089
        }
203✔
3090

3091
        for _, wt := range ce.whenThen {
734✔
3092
                v, err := wt.when.reduce(tx, row, implicitTable)
430✔
3093
                if err != nil {
430✔
3094
                        return nil, err
×
3095
                }
×
3096

3097
                if v.Type() != searchValue.Type() {
431✔
3098
                        return nil, fmt.Errorf("%w: argument of CASE/WHEN must be type %s, not type %s", ErrInvalidTypes, v.Type(), searchValue.Type())
1✔
3099
                }
1✔
3100

3101
                res, err := v.Compare(searchValue)
429✔
3102
                if err != nil {
429✔
3103
                        return nil, err
×
3104
                }
×
3105
                if res == 0 {
629✔
3106
                        return wt.then.reduce(tx, row, implicitTable)
200✔
3107
                }
200✔
3108
        }
3109

3110
        if ce.elseExp == nil {
104✔
3111
                return NewNull(AnyType), nil
1✔
3112
        }
1✔
3113
        return ce.elseExp.reduce(tx, row, implicitTable)
102✔
3114
}
3115

3116
func (ce *CaseWhenExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
3117
        whenThen := make([]whenThenClause, len(ce.whenThen))
1✔
3118
        for i, wt := range ce.whenThen {
2✔
3119
                whenValue := wt.when.reduceSelectors(row, implicitTable)
1✔
3120
                whenThen[i].when = whenValue
1✔
3121

1✔
3122
                thenValue := wt.then.reduceSelectors(row, implicitTable)
1✔
3123
                whenThen[i].then = thenValue
1✔
3124
        }
1✔
3125

3126
        if ce.elseExp == nil {
1✔
3127
                return &CaseWhenExp{
×
3128
                        whenThen: whenThen,
×
3129
                }
×
3130
        }
×
3131

3132
        return &CaseWhenExp{
1✔
3133
                whenThen: whenThen,
1✔
3134
                elseExp:  ce.elseExp.reduceSelectors(row, implicitTable),
1✔
3135
        }
1✔
3136
}
3137

3138
func (ce *CaseWhenExp) isConstant() bool {
1✔
3139
        return false
1✔
3140
}
1✔
3141

3142
func (ce *CaseWhenExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
3143
        return nil
1✔
3144
}
1✔
3145

3146
func (ce *CaseWhenExp) String() string {
3✔
3147
        var sb strings.Builder
3✔
3148
        for _, wh := range ce.whenThen {
7✔
3149
                sb.WriteString(fmt.Sprintf("WHEN %s THEN %s ", wh.when.String(), wh.then.String()))
4✔
3150
        }
4✔
3151

3152
        if ce.elseExp != nil {
5✔
3153
                sb.WriteString("ELSE " + ce.elseExp.String() + " ")
2✔
3154
        }
2✔
3155
        return "CASE " + sb.String() + "END"
3✔
3156
}
3157

3158
// Comparison identifies a relational comparison operator.
type Comparison int

// Comparison operators; the numeric values start at 0 in declaration order.
const (
	// EqualTo is "=".
	EqualTo Comparison = iota
	// LowerThan is "<".
	LowerThan
	// LowerOrEqualTo is "<=".
	LowerOrEqualTo
	// GreaterThan is ">".
	GreaterThan
	// GreaterOrEqualTo is ">=".
	GreaterOrEqualTo
)
3167

3168
type DataSource interface {
3169
        SQLStmt
3170
        Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error)
3171
        Alias() string
3172
}
3173

3174
type TargetEntry struct {
3175
        Exp ValueExp
3176
        As  string
3177
}
3178

3179
type SelectStmt struct {
3180
        distinct  bool
3181
        targets   []TargetEntry
3182
        selectors []Selector
3183
        ds        DataSource
3184
        indexOn   []string
3185
        joins     []*JoinSpec
3186
        where     ValueExp
3187
        groupBy   []*ColSelector
3188
        having    ValueExp
3189
        orderBy   []*OrdExp
3190
        limit     ValueExp
3191
        offset    ValueExp
3192
        as        string
3193
}
3194

3195
func NewSelectStmt(
3196
        targets []TargetEntry,
3197
        ds DataSource,
3198
        where ValueExp,
3199
        orderBy []*OrdExp,
3200
        limit ValueExp,
3201
        offset ValueExp,
3202
) *SelectStmt {
71✔
3203
        return &SelectStmt{
71✔
3204
                targets: targets,
71✔
3205
                ds:      ds,
71✔
3206
                where:   where,
71✔
3207
                orderBy: orderBy,
71✔
3208
                limit:   limit,
71✔
3209
                offset:  offset,
71✔
3210
        }
71✔
3211
}
71✔
3212

3213
func (stmt *SelectStmt) readOnly() bool {
94✔
3214
        return true
94✔
3215
}
94✔
3216

3217
func (stmt *SelectStmt) requiredPrivileges() []SQLPrivilege {
96✔
3218
        return []SQLPrivilege{SQLPrivilegeSelect}
96✔
3219
}
96✔
3220

3221
func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
53✔
3222
        _, err := stmt.execAt(ctx, tx, nil)
53✔
3223
        if err != nil {
53✔
3224
                return err
×
3225
        }
×
3226

3227
        // TODO: (jeroiraz) may be optimized so to resolve the query statement just once
3228
        rowReader, err := stmt.Resolve(ctx, tx, nil, nil)
53✔
3229
        if err != nil {
54✔
3230
                return err
1✔
3231
        }
1✔
3232
        defer rowReader.Close()
52✔
3233

52✔
3234
        return rowReader.InferParameters(ctx, params)
52✔
3235
}
3236

3237
func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
612✔
3238
        if stmt.groupBy == nil && stmt.having != nil {
613✔
3239
                return nil, ErrHavingClauseRequiresGroupClause
1✔
3240
        }
1✔
3241

3242
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
696✔
3243
                for _, sel := range stmt.targetSelectors() {
227✔
3244
                        _, isAgg := sel.(*AggColSelector)
142✔
3245
                        if !isAgg && !stmt.groupByContains(sel) {
144✔
3246
                                return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3247
                        }
2✔
3248
                }
3249
        }
3250

3251
        if len(stmt.orderBy) > 0 {
763✔
3252
                for _, col := range stmt.orderBy {
346✔
3253
                        for _, sel := range col.exp.selectors() {
374✔
3254
                                _, isAgg := sel.(*AggColSelector)
182✔
3255
                                if (isAgg && !stmt.selectorAppearsInTargets(sel)) || (!isAgg && len(stmt.groupBy) > 0 && !stmt.groupByContains(sel)) {
184✔
3256
                                        return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation)
2✔
3257
                                }
2✔
3258
                        }
3259
                }
3260
        }
3261
        return tx, nil
607✔
3262
}
3263

3264
func (stmt *SelectStmt) targetSelectors() []Selector {
2,503✔
3265
        if stmt.selectors == nil {
3,446✔
3266
                stmt.selectors = stmt.extractSelectors()
943✔
3267
        }
943✔
3268
        return stmt.selectors
2,503✔
3269
}
3270

3271
func (stmt *SelectStmt) selectorAppearsInTargets(s Selector) bool {
4✔
3272
        encSel := EncodeSelector(s.resolve(stmt.Alias()))
4✔
3273

4✔
3274
        for _, sel := range stmt.targetSelectors() {
12✔
3275
                if EncodeSelector(sel.resolve(stmt.Alias())) == encSel {
11✔
3276
                        return true
3✔
3277
                }
3✔
3278
        }
3279
        return false
1✔
3280
}
3281

3282
func (stmt *SelectStmt) groupByContains(sel Selector) bool {
57✔
3283
        encSel := EncodeSelector(sel.resolve(stmt.Alias()))
57✔
3284

57✔
3285
        for _, colSel := range stmt.groupBy {
137✔
3286
                if EncodeSelector(colSel.resolve(stmt.Alias())) == encSel {
134✔
3287
                        return true
54✔
3288
                }
54✔
3289
        }
3290
        return false
3✔
3291
}
3292

3293
func (stmt *SelectStmt) extractGroupByCols() []*AggColSelector {
80✔
3294
        cols := make([]*AggColSelector, 0, len(stmt.targets))
80✔
3295

80✔
3296
        for _, t := range stmt.targets {
214✔
3297
                selectors := t.Exp.selectors()
134✔
3298
                for _, sel := range selectors {
268✔
3299
                        aggSel, isAgg := sel.(*AggColSelector)
134✔
3300
                        if isAgg {
244✔
3301
                                cols = append(cols, aggSel)
110✔
3302
                        }
110✔
3303
                }
3304
        }
3305
        return cols
80✔
3306
}
3307

3308
func (stmt *SelectStmt) extractSelectors() []Selector {
943✔
3309
        selectors := make([]Selector, 0, len(stmt.targets))
943✔
3310
        for _, t := range stmt.targets {
1,807✔
3311
                selectors = append(selectors, t.Exp.selectors()...)
864✔
3312
        }
864✔
3313
        return selectors
943✔
3314
}
3315

3316
func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
950✔
3317
        scanSpecs, err := stmt.genScanSpecs(tx, params)
950✔
3318
        if err != nil {
966✔
3319
                return nil, err
16✔
3320
        }
16✔
3321

3322
        rowReader, err := stmt.ds.Resolve(ctx, tx, params, scanSpecs)
934✔
3323
        if err != nil {
937✔
3324
                return nil, err
3✔
3325
        }
3✔
3326
        defer func() {
1,862✔
3327
                if err != nil {
937✔
3328
                        rowReader.Close()
6✔
3329
                }
6✔
3330
        }()
3331

3332
        if stmt.joins != nil {
946✔
3333
                var jointRowReader *jointRowReader
15✔
3334
                jointRowReader, err = newJointRowReader(rowReader, stmt.joins)
15✔
3335
                if err != nil {
16✔
3336
                        return nil, err
1✔
3337
                }
1✔
3338
                rowReader = jointRowReader
14✔
3339
        }
3340

3341
        if stmt.where != nil {
1,457✔
3342
                rowReader = newConditionalRowReader(rowReader, stmt.where)
527✔
3343
        }
527✔
3344

3345
        if stmt.containsAggregations() || len(stmt.groupBy) > 0 {
1,010✔
3346
                if len(scanSpecs.groupBySortExps) > 0 {
92✔
3347
                        var sortRowReader *sortRowReader
12✔
3348
                        sortRowReader, err = newSortRowReader(rowReader, scanSpecs.groupBySortExps)
12✔
3349
                        if err != nil {
12✔
3350
                                return nil, err
×
3351
                        }
×
3352
                        rowReader = sortRowReader
12✔
3353
                }
3354

3355
                var groupedRowReader *groupedRowReader
80✔
3356
                groupedRowReader, err = newGroupedRowReader(rowReader, allAggregations(stmt.targets), stmt.extractGroupByCols(), stmt.groupBy)
80✔
3357
                if err != nil {
82✔
3358
                        return nil, err
2✔
3359
                }
2✔
3360
                rowReader = groupedRowReader
78✔
3361

78✔
3362
                if stmt.having != nil {
82✔
3363
                        rowReader = newConditionalRowReader(rowReader, stmt.having)
4✔
3364
                }
4✔
3365
        }
3366

3367
        if len(scanSpecs.orderBySortExps) > 0 {
977✔
3368
                var sortRowReader *sortRowReader
49✔
3369
                sortRowReader, err = newSortRowReader(rowReader, stmt.orderBy)
49✔
3370
                if err != nil {
50✔
3371
                        return nil, err
1✔
3372
                }
1✔
3373
                rowReader = sortRowReader
48✔
3374
        }
3375

3376
        projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.targets)
927✔
3377
        if err != nil {
928✔
3378
                return nil, err
1✔
3379
        }
1✔
3380
        rowReader = projectedRowReader
926✔
3381

926✔
3382
        if stmt.distinct {
934✔
3383
                var distinctRowReader *distinctRowReader
8✔
3384
                distinctRowReader, err = newDistinctRowReader(ctx, rowReader)
8✔
3385
                if err != nil {
9✔
3386
                        return nil, err
1✔
3387
                }
1✔
3388
                rowReader = distinctRowReader
7✔
3389
        }
3390

3391
        if stmt.offset != nil {
975✔
3392
                var offset int
50✔
3393
                offset, err = evalExpAsInt(tx, stmt.offset, params)
50✔
3394
                if err != nil {
50✔
3395
                        return nil, fmt.Errorf("%w: invalid offset", err)
×
3396
                }
×
3397

3398
                rowReader = newOffsetRowReader(rowReader, offset)
50✔
3399
        }
3400

3401
        if stmt.limit != nil {
1,022✔
3402
                var limit int
97✔
3403
                limit, err = evalExpAsInt(tx, stmt.limit, params)
97✔
3404
                if err != nil {
97✔
3405
                        return nil, fmt.Errorf("%w: invalid limit", err)
×
3406
                }
×
3407

3408
                if limit < 0 {
97✔
3409
                        return nil, fmt.Errorf("%w: invalid limit", ErrIllegalArguments)
×
3410
                }
×
3411

3412
                if limit > 0 {
140✔
3413
                        rowReader = newLimitRowReader(rowReader, limit)
43✔
3414
                }
43✔
3415
        }
3416
        return rowReader, nil
925✔
3417
}
3418

3419
func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
933✔
3420
        if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
939✔
3421
                if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {
8✔
3422
                        return orderByExps, nil
2✔
3423
                }
2✔
3424

3425
                if ordExpsHasPrefix(groupByCols, orderByExps, stmt.Alias()) {
5✔
3426
                        for i := range orderByExps {
2✔
3427
                                groupByCols[i].descOrder = orderByExps[i].descOrder
1✔
3428
                        }
1✔
3429
                        return groupByCols, nil
1✔
3430
                }
3431
        }
3432
        return groupByCols, orderByExps
930✔
3433
}
3434

3435
func ordExpsHasPrefix(cols, prefix []*OrdExp, table string) bool {
10✔
3436
        if len(prefix) > len(cols) {
12✔
3437
                return false
2✔
3438
        }
2✔
3439

3440
        for i := range prefix {
17✔
3441
                ls := prefix[i].AsSelector()
9✔
3442
                rs := cols[i].AsSelector()
9✔
3443

9✔
3444
                if ls == nil || rs == nil {
9✔
3445
                        return false
×
3446
                }
×
3447

3448
                if EncodeSelector(ls.resolve(table)) != EncodeSelector(rs.resolve(table)) {
14✔
3449
                        return false
5✔
3450
                }
5✔
3451
        }
3452
        return true
3✔
3453
}
3454

3455
func (stmt *SelectStmt) groupByOrdExps() []*OrdExp {
950✔
3456
        groupByCols := stmt.groupBy
950✔
3457

950✔
3458
        ordExps := make([]*OrdExp, 0, len(groupByCols))
950✔
3459
        for _, col := range groupByCols {
997✔
3460
                ordExps = append(ordExps, &OrdExp{exp: col})
47✔
3461
        }
47✔
3462
        return ordExps
950✔
3463
}
3464

3465
func ordExpsHaveAggregations(exps []*OrdExp) bool {
7✔
3466
        for _, e := range exps {
17✔
3467
                if _, isAgg := e.exp.(*AggColSelector); isAgg {
11✔
3468
                        return true
1✔
3469
                }
1✔
3470
        }
3471
        return false
6✔
3472
}
3473

3474
func (stmt *SelectStmt) containsAggregations() bool {
1,541✔
3475
        for _, sel := range stmt.targetSelectors() {
3,117✔
3476
                _, isAgg := sel.(*AggColSelector)
1,576✔
3477
                if isAgg {
1,739✔
3478
                        return true
163✔
3479
                }
163✔
3480
        }
3481
        return false
1,378✔
3482
}
3483

3484
func evalExpAsInt(tx *SQLTx, exp ValueExp, params map[string]interface{}) (int, error) {
147✔
3485
        offset, err := exp.substitute(params)
147✔
3486
        if err != nil {
147✔
3487
                return 0, err
×
3488
        }
×
3489

3490
        texp, err := offset.reduce(tx, nil, "")
147✔
3491
        if err != nil {
147✔
3492
                return 0, err
×
3493
        }
×
3494

3495
        convVal, err := mayApplyImplicitConversion(texp.RawValue(), IntegerType)
147✔
3496
        if err != nil {
147✔
3497
                return 0, ErrInvalidValue
×
3498
        }
×
3499

3500
        num, ok := convVal.(int64)
147✔
3501
        if !ok {
147✔
3502
                return 0, ErrInvalidValue
×
3503
        }
×
3504

3505
        if num > math.MaxInt32 {
147✔
3506
                return 0, ErrInvalidValue
×
3507
        }
×
3508

3509
        return int(num), nil
147✔
3510
}
3511

3512
func (stmt *SelectStmt) Alias() string {
167✔
3513
        if stmt.as == "" {
333✔
3514
                return stmt.ds.Alias()
166✔
3515
        }
166✔
3516

3517
        return stmt.as
1✔
3518
}
3519

3520
func (stmt *SelectStmt) hasTxMetadata() bool {
873✔
3521
        for _, sel := range stmt.targetSelectors() {
1,679✔
3522
                switch s := sel.(type) {
806✔
3523
                case *ColSelector:
692✔
3524
                        if s.col == txMetadataCol {
693✔
3525
                                return true
1✔
3526
                        }
1✔
3527
                case *JSONSelector:
21✔
3528
                        if s.ColSelector.col == txMetadataCol {
24✔
3529
                                return true
3✔
3530
                        }
3✔
3531
                }
3532
        }
3533
        return false
869✔
3534
}
3535

3536
func (stmt *SelectStmt) genScanSpecs(tx *SQLTx, params map[string]interface{}) (*ScanSpecs, error) {
950✔
3537
        groupByCols, orderByCols := stmt.groupByOrdExps(), stmt.orderBy
950✔
3538

950✔
3539
        tableRef, isTableRef := stmt.ds.(*tableRef)
950✔
3540
        if !isTableRef {
1,010✔
3541
                groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)
60✔
3542

60✔
3543
                return &ScanSpecs{
60✔
3544
                        groupBySortExps: groupByCols,
60✔
3545
                        orderBySortExps: orderByCols,
60✔
3546
                }, nil
60✔
3547
        }
60✔
3548

3549
        table, err := tableRef.referencedTable(tx)
890✔
3550
        if err != nil {
905✔
3551
                if tx.engine.tableResolveFor(tableRef.table) != nil {
16✔
3552
                        return &ScanSpecs{
1✔
3553
                                groupBySortExps: groupByCols,
1✔
3554
                                orderBySortExps: orderByCols,
1✔
3555
                        }, nil
1✔
3556
                }
1✔
3557
                return nil, err
14✔
3558
        }
3559

3560
        rangesByColID := make(map[uint32]*typedValueRange)
875✔
3561
        if stmt.where != nil {
1,391✔
3562
                err = stmt.where.selectorRanges(table, tableRef.Alias(), params, rangesByColID)
516✔
3563
                if err != nil {
518✔
3564
                        return nil, err
2✔
3565
                }
2✔
3566
        }
3567

3568
        preferredIndex, err := stmt.getPreferredIndex(table)
873✔
3569
        if err != nil {
873✔
3570
                return nil, err
×
3571
        }
×
3572

3573
        var sortingIndex *Index
873✔
3574
        if preferredIndex == nil {
1,716✔
3575
                sortingIndex = stmt.selectSortingIndex(groupByCols, orderByCols, table, rangesByColID)
843✔
3576
        } else {
873✔
3577
                sortingIndex = preferredIndex
30✔
3578
        }
30✔
3579

3580
        if sortingIndex == nil {
1,626✔
3581
                sortingIndex = table.primaryIndex
753✔
3582
        }
753✔
3583

3584
        if tableRef.history && !sortingIndex.IsPrimary() {
873✔
3585
                return nil, fmt.Errorf("%w: historical queries are supported over primary index", ErrIllegalArguments)
×
3586
        }
×
3587

3588
        var descOrder bool
873✔
3589
        if len(groupByCols) > 0 && sortingIndex.coversOrdCols(groupByCols, rangesByColID) {
890✔
3590
                groupByCols = nil
17✔
3591
        }
17✔
3592

3593
        if len(groupByCols) == 0 && len(orderByCols) > 0 && sortingIndex.coversOrdCols(orderByCols, rangesByColID) {
972✔
3594
                descOrder = orderByCols[0].descOrder
99✔
3595
                orderByCols = nil
99✔
3596
        }
99✔
3597

3598
        groupByCols, orderByCols = stmt.rearrangeOrdExps(groupByCols, orderByCols)
873✔
3599

873✔
3600
        return &ScanSpecs{
873✔
3601
                Index:             sortingIndex,
873✔
3602
                rangesByColID:     rangesByColID,
873✔
3603
                IncludeHistory:    tableRef.history,
873✔
3604
                IncludeTxMetadata: stmt.hasTxMetadata(),
873✔
3605
                DescOrder:         descOrder,
873✔
3606
                groupBySortExps:   groupByCols,
873✔
3607
                orderBySortExps:   orderByCols,
873✔
3608
        }, nil
873✔
3609
}
3610

3611
func (stmt *SelectStmt) selectSortingIndex(groupByCols, orderByCols []*OrdExp, table *Table, rangesByColId map[uint32]*typedValueRange) *Index {
843✔
3612
        sortCols := groupByCols
843✔
3613
        if len(sortCols) == 0 {
1,660✔
3614
                sortCols = orderByCols
817✔
3615
        }
817✔
3616

3617
        if len(sortCols) == 0 {
1,550✔
3618
                return nil
707✔
3619
        }
707✔
3620

3621
        for _, idx := range table.indexes {
358✔
3622
                if idx.coversOrdCols(sortCols, rangesByColId) {
312✔
3623
                        return idx
90✔
3624
                }
90✔
3625
        }
3626
        return nil
46✔
3627
}
3628

3629
func (stmt *SelectStmt) getPreferredIndex(table *Table) (*Index, error) {
873✔
3630
        if len(stmt.indexOn) == 0 {
1,716✔
3631
                return nil, nil
843✔
3632
        }
843✔
3633

3634
        cols := make([]*Column, len(stmt.indexOn))
30✔
3635
        for i, colName := range stmt.indexOn {
80✔
3636
                col, err := table.GetColumnByName(colName)
50✔
3637
                if err != nil {
50✔
3638
                        return nil, err
×
3639
                }
×
3640

3641
                cols[i] = col
50✔
3642
        }
3643
        return table.GetIndexByName(indexName(table.name, cols))
30✔
3644
}
3645

3646
type UnionStmt struct {
3647
        distinct    bool
3648
        left, right DataSource
3649
}
3650

3651
func (stmt *UnionStmt) readOnly() bool {
1✔
3652
        return true
1✔
3653
}
1✔
3654

3655
func (stmt *UnionStmt) requiredPrivileges() []SQLPrivilege {
1✔
3656
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3657
}
1✔
3658

3659
func (stmt *UnionStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3660
        err := stmt.left.inferParameters(ctx, tx, params)
1✔
3661
        if err != nil {
1✔
3662
                return err
×
3663
        }
×
3664
        return stmt.right.inferParameters(ctx, tx, params)
1✔
3665
}
3666

3667
func (stmt *UnionStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
9✔
3668
        _, err := stmt.left.execAt(ctx, tx, params)
9✔
3669
        if err != nil {
9✔
3670
                return tx, err
×
3671
        }
×
3672

3673
        return stmt.right.execAt(ctx, tx, params)
9✔
3674
}
3675

3676
func (stmt *UnionStmt) resolveUnionAll(ctx context.Context, tx *SQLTx, params map[string]interface{}) (ret RowReader, err error) {
11✔
3677
        leftRowReader, err := stmt.left.Resolve(ctx, tx, params, nil)
11✔
3678
        if err != nil {
12✔
3679
                return nil, err
1✔
3680
        }
1✔
3681
        defer func() {
20✔
3682
                if err != nil {
14✔
3683
                        leftRowReader.Close()
4✔
3684
                }
4✔
3685
        }()
3686

3687
        rightRowReader, err := stmt.right.Resolve(ctx, tx, params, nil)
10✔
3688
        if err != nil {
11✔
3689
                return nil, err
1✔
3690
        }
1✔
3691
        defer func() {
18✔
3692
                if err != nil {
12✔
3693
                        rightRowReader.Close()
3✔
3694
                }
3✔
3695
        }()
3696

3697
        rowReader, err := newUnionRowReader(ctx, []RowReader{leftRowReader, rightRowReader})
9✔
3698
        if err != nil {
12✔
3699
                return nil, err
3✔
3700
        }
3✔
3701

3702
        return rowReader, nil
6✔
3703
}
3704

3705
func (stmt *UnionStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) {
11✔
3706
        rowReader, err := stmt.resolveUnionAll(ctx, tx, params)
11✔
3707
        if err != nil {
16✔
3708
                return nil, err
5✔
3709
        }
5✔
3710
        defer func() {
12✔
3711
                if err != nil {
7✔
3712
                        rowReader.Close()
1✔
3713
                }
1✔
3714
        }()
3715

3716
        if stmt.distinct {
11✔
3717
                distinctReader, err := newDistinctRowReader(ctx, rowReader)
5✔
3718
                if err != nil {
6✔
3719
                        return nil, err
1✔
3720
                }
1✔
3721
                rowReader = distinctReader
4✔
3722
        }
3723

3724
        return rowReader, nil
5✔
3725
}
3726

3727
func (stmt *UnionStmt) Alias() string {
×
3728
        return ""
×
3729
}
×
3730

3731
func NewTableRef(table string, as string) *tableRef {
179✔
3732
        return &tableRef{
179✔
3733
                table: table,
179✔
3734
                as:    as,
179✔
3735
        }
179✔
3736
}
179✔
3737

3738
type tableRef struct {
3739
        table   string
3740
        history bool
3741
        period  period
3742
        as      string
3743
}
3744

3745
func (ref *tableRef) readOnly() bool {
1✔
3746
        return true
1✔
3747
}
1✔
3748

3749
func (ref *tableRef) requiredPrivileges() []SQLPrivilege {
1✔
3750
        return []SQLPrivilege{SQLPrivilegeSelect}
1✔
3751
}
1✔
3752

3753
type period struct {
3754
        start *openPeriod
3755
        end   *openPeriod
3756
}
3757

3758
type openPeriod struct {
3759
        inclusive bool
3760
        instant   periodInstant
3761
}
3762

3763
type periodInstant struct {
3764
        exp         ValueExp
3765
        instantType instantType
3766
}
3767

3768
type instantType = int
3769

3770
const (
3771
        txInstant instantType = iota
3772
        timeInstant
3773
)
3774

3775
func (i periodInstant) resolve(tx *SQLTx, params map[string]interface{}, asc, inclusive bool) (uint64, error) {
81✔
3776
        exp, err := i.exp.substitute(params)
81✔
3777
        if err != nil {
81✔
3778
                return 0, err
×
3779
        }
×
3780

3781
        instantVal, err := exp.reduce(tx, nil, "")
81✔
3782
        if err != nil {
83✔
3783
                return 0, err
2✔
3784
        }
2✔
3785

3786
        if i.instantType == txInstant {
124✔
3787
                txID, ok := instantVal.RawValue().(int64)
45✔
3788
                if !ok {
45✔
3789
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %s given", ErrIllegalArguments, instantVal.Type())
×
3790
                }
×
3791

3792
                if txID <= 0 {
52✔
3793
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be a positive integer, %d given", ErrIllegalArguments, txID)
7✔
3794
                }
7✔
3795

3796
                if inclusive {
61✔
3797
                        return uint64(txID), nil
23✔
3798
                }
23✔
3799

3800
                if asc {
26✔
3801
                        return uint64(txID + 1), nil
11✔
3802
                }
11✔
3803

3804
                if txID <= 1 {
5✔
3805
                        return 0, fmt.Errorf("%w: invalid tx range, tx ID must be greater than 1, %d given", ErrIllegalArguments, txID)
1✔
3806
                }
1✔
3807

3808
                return uint64(txID - 1), nil
3✔
3809
        } else {
34✔
3810

34✔
3811
                var ts time.Time
34✔
3812

34✔
3813
                if instantVal.Type() == TimestampType {
67✔
3814
                        ts = instantVal.RawValue().(time.Time)
33✔
3815
                } else {
34✔
3816
                        conv, err := getConverter(instantVal.Type(), TimestampType)
1✔
3817
                        if err != nil {
1✔
3818
                                return 0, err
×
3819
                        }
×
3820

3821
                        tval, err := conv(instantVal)
1✔
3822
                        if err != nil {
1✔
3823
                                return 0, err
×
3824
                        }
×
3825

3826
                        ts = tval.RawValue().(time.Time)
1✔
3827
                }
3828

3829
                sts := ts
34✔
3830

34✔
3831
                if asc {
57✔
3832
                        if !inclusive {
34✔
3833
                                sts = sts.Add(1 * time.Second)
11✔
3834
                        }
11✔
3835

3836
                        txHdr, err := tx.engine.store.FirstTxSince(sts)
23✔
3837
                        if err != nil {
34✔
3838
                                return 0, err
11✔
3839
                        }
11✔
3840

3841
                        return txHdr.ID, nil
12✔
3842
                }
3843

3844
                if !inclusive {
11✔
3845
                        sts = sts.Add(-1 * time.Second)
×
3846
                }
×
3847

3848
                txHdr, err := tx.engine.store.LastTxUntil(sts)
11✔
3849
                if err != nil {
11✔
3850
                        return 0, err
×
3851
                }
×
3852

3853
                return txHdr.ID, nil
11✔
3854
        }
3855
}
3856

3857
func (stmt *tableRef) referencedTable(tx *SQLTx) (*Table, error) {
3,936✔
3858
        table, err := tx.catalog.GetTableByName(stmt.table)
3,936✔
3859
        if err != nil {
3,956✔
3860
                return nil, err
20✔
3861
        }
20✔
3862
        return table, nil
3,916✔
3863
}
3864

3865
func (stmt *tableRef) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
3866
        return nil
1✔
3867
}
1✔
3868

3869
func (stmt *tableRef) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3870
        return tx, nil
×
3871
}
×
3872

3873
func (stmt *tableRef) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
903✔
3874
        if tx == nil {
903✔
3875
                return nil, ErrIllegalArguments
×
3876
        }
×
3877

3878
        table, err := stmt.referencedTable(tx)
903✔
3879
        if err == nil {
1,805✔
3880
                return newRawRowReader(tx, params, table, stmt.period, stmt.as, scanSpecs)
902✔
3881
        }
902✔
3882

3883
        if resolver := tx.engine.tableResolveFor(stmt.table); resolver != nil {
2✔
3884
                return resolver.Resolve(ctx, tx, stmt.Alias())
1✔
3885
        }
1✔
NEW
3886
        return nil, err
×
3887
}
3888

3889
func (stmt *tableRef) Alias() string {
681✔
3890
        if stmt.as == "" {
1,202✔
3891
                return stmt.table
521✔
3892
        }
521✔
3893
        return stmt.as
160✔
3894
}
3895

3896
type valuesDataSource struct {
3897
        inferTypes bool
3898
        rows       []*RowSpec
3899
}
3900

3901
func NewValuesDataSource(rows []*RowSpec) *valuesDataSource {
120✔
3902
        return &valuesDataSource{
120✔
3903
                rows: rows,
120✔
3904
        }
120✔
3905
}
120✔
3906

3907
func (ds *valuesDataSource) readOnly() bool {
×
3908
        return true
×
3909
}
×
3910

3911
func (ds *valuesDataSource) requiredPrivileges() []SQLPrivilege {
97✔
3912
        return nil
97✔
3913
}
97✔
3914

3915
func (ds *valuesDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
3916
        return tx, nil
×
3917
}
×
3918

3919
func (ds *valuesDataSource) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
×
3920
        return nil
×
3921
}
×
3922

3923
func (ds *valuesDataSource) Alias() string {
×
3924
        return ""
×
3925
}
×
3926

3927
func (ds *valuesDataSource) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) {
2,116✔
3928
        if tx == nil {
2,116✔
3929
                return nil, ErrIllegalArguments
×
3930
        }
×
3931

3932
        cols := make([]ColDescriptor, len(ds.rows[0].Values))
2,116✔
3933
        for i := range cols {
9,752✔
3934
                cols[i] = ColDescriptor{
7,636✔
3935
                        Type:   AnyType,
7,636✔
3936
                        Column: fmt.Sprintf("col%d", i),
7,636✔
3937
                }
7,636✔
3938
        }
7,636✔
3939

3940
        emptyColsDesc, emptyParams := map[string]ColDescriptor{}, map[string]string{}
2,116✔
3941

2,116✔
3942
        if ds.inferTypes {
2,122✔
3943
                for i := 0; i < len(cols); i++ {
30✔
3944
                        t := AnyType
24✔
3945
                        for j := 0; j < len(ds.rows); j++ {
110✔
3946
                                e, err := ds.rows[j].Values[i].substitute(params)
86✔
3947
                                if err != nil {
86✔
3948
                                        return nil, err
×
3949
                                }
×
3950

3951
                                it, err := e.inferType(emptyColsDesc, emptyParams, "")
86✔
3952
                                if err != nil {
86✔
3953
                                        return nil, err
×
3954
                                }
×
3955

3956
                                if t == AnyType {
110✔
3957
                                        t = it
24✔
3958
                                } else if t != it && it != AnyType {
88✔
3959
                                        return nil, fmt.Errorf("cannot match types %s and %s", t, it)
2✔
3960
                                }
2✔
3961
                        }
3962
                        cols[i].Type = t
22✔
3963
                }
3964
        }
3965

3966
        values := make([][]ValueExp, len(ds.rows))
2,114✔
3967
        for i, rowSpec := range ds.rows {
4,335✔
3968
                values[i] = rowSpec.Values
2,221✔
3969
        }
2,221✔
3970
        return NewValuesRowReader(tx, params, cols, ds.inferTypes, "values", values)
2,114✔
3971
}
3972

3973
type JoinSpec struct {
3974
        joinType JoinType
3975
        ds       DataSource
3976
        cond     ValueExp
3977
        indexOn  []string
3978
}
3979

3980
type OrdExp struct {
3981
        exp       ValueExp
3982
        descOrder bool
3983
}
3984

3985
func (oc *OrdExp) AsSelector() Selector {
708✔
3986
        sel, ok := oc.exp.(Selector)
708✔
3987
        if ok {
1,362✔
3988
                return sel
654✔
3989
        }
654✔
3990
        return nil
54✔
3991
}
3992

3993
func NewOrdCol(table string, col string, descOrder bool) *OrdExp {
1✔
3994
        return &OrdExp{
1✔
3995
                exp:       NewColSelector(table, col),
1✔
3996
                descOrder: descOrder,
1✔
3997
        }
1✔
3998
}
1✔
3999

4000
type Selector interface {
4001
        ValueExp
4002
        resolve(implicitTable string) (aggFn, table, col string)
4003
}
4004

4005
type ColSelector struct {
4006
        table string
4007
        col   string
4008
}
4009

4010
func NewColSelector(table, col string) *ColSelector {
126✔
4011
        return &ColSelector{
126✔
4012
                table: table,
126✔
4013
                col:   col,
126✔
4014
        }
126✔
4015
}
126✔
4016

4017
func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) {
867,560✔
4018
        table = implicitTable
867,560✔
4019
        if sel.table != "" {
1,168,526✔
4020
                table = sel.table
300,966✔
4021
        }
300,966✔
4022
        return "", table, sel.col
867,560✔
4023
}
4024

4025
func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
685✔
4026
        _, table, col := sel.resolve(implicitTable)
685✔
4027
        encSel := EncodeSelector("", table, col)
685✔
4028

685✔
4029
        desc, ok := cols[encSel]
685✔
4030
        if !ok {
688✔
4031
                return AnyType, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
3✔
4032
        }
3✔
4033
        return desc.Type, nil
682✔
4034
}
4035

4036
func (sel *ColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
15✔
4037
        _, table, col := sel.resolve(implicitTable)
15✔
4038
        encSel := EncodeSelector("", table, col)
15✔
4039

15✔
4040
        desc, ok := cols[encSel]
15✔
4041
        if !ok {
17✔
4042
                return fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
2✔
4043
        }
2✔
4044

4045
        if desc.Type != t {
16✔
4046
                return fmt.Errorf("%w: %v(%s) can not be interpreted as type %v", ErrInvalidTypes, desc.Type, encSel, t)
3✔
4047
        }
3✔
4048

4049
        return nil
10✔
4050
}
4051

4052
func (sel *ColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
161,807✔
4053
        return sel, nil
161,807✔
4054
}
161,807✔
4055

4056
func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
715,093✔
4057
        if row == nil {
715,094✔
4058
                return nil, fmt.Errorf("%w: no row to evaluate in current context", ErrInvalidValue)
1✔
4059
        }
1✔
4060

4061
        aggFn, table, col := sel.resolve(implicitTable)
715,092✔
4062

715,092✔
4063
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
715,092✔
4064
        if !ok {
715,099✔
4065
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, col)
7✔
4066
        }
7✔
4067
        return v, nil
715,085✔
4068
}
4069

4070
func (sel *ColSelector) selectors() []Selector {
913✔
4071
        return []Selector{sel}
913✔
4072
}
913✔
4073

4074
func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
568✔
4075
        aggFn, table, col := sel.resolve(implicitTable)
568✔
4076

568✔
4077
        v, ok := row.ValuesBySelector[EncodeSelector(aggFn, table, col)]
568✔
4078
        if !ok {
846✔
4079
                return sel
278✔
4080
        }
278✔
4081

4082
        return v
290✔
4083
}
4084

4085
func (sel *ColSelector) isConstant() bool {
12✔
4086
        return false
12✔
4087
}
12✔
4088

4089
func (sel *ColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
11✔
4090
        return nil
11✔
4091
}
11✔
4092

4093
func (sel *ColSelector) String() string {
42✔
4094
        return sel.col
42✔
4095
}
42✔
4096

4097
type AggColSelector struct {
4098
        aggFn AggregateFn
4099
        table string
4100
        col   string
4101
}
4102

4103
func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector {
16✔
4104
        return &AggColSelector{
16✔
4105
                aggFn: aggFn,
16✔
4106
                table: table,
16✔
4107
                col:   col,
16✔
4108
        }
16✔
4109
}
16✔
4110

4111
// EncodeSelector builds the key under which a selector's value is stored in a
// row: "aggFn(table.col)", with aggFn empty for plain columns. For example,
// "(t.id)" or "COUNT(t.id)".
func EncodeSelector(aggFn, table, col string) string {
	var b strings.Builder
	b.Grow(len(aggFn) + len(table) + len(col) + 3) // "(", "." and ")"
	b.WriteString(aggFn)
	b.WriteByte('(')
	b.WriteString(table)
	b.WriteByte('.')
	b.WriteString(col)
	b.WriteByte(')')
	return b.String()
}
1,417,262✔
4114

4115
func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col string) {
1,586✔
4116
        table = implicitTable
1,586✔
4117
        if sel.table != "" {
1,717✔
4118
                table = sel.table
131✔
4119
        }
131✔
4120
        return sel.aggFn, table, sel.col
1,586✔
4121
}
4122

4123
func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
36✔
4124
        if sel.aggFn == COUNT {
55✔
4125
                return IntegerType, nil
19✔
4126
        }
19✔
4127

4128
        colSelector := &ColSelector{table: sel.table, col: sel.col}
17✔
4129

17✔
4130
        if sel.aggFn == SUM || sel.aggFn == AVG {
24✔
4131
                t, err := colSelector.inferType(cols, params, implicitTable)
7✔
4132
                if err != nil {
7✔
4133
                        return AnyType, err
×
4134
                }
×
4135

4136
                if t != IntegerType && t != Float64Type {
7✔
4137
                        return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
×
4138

×
4139
                }
×
4140

4141
                return t, nil
7✔
4142
        }
4143

4144
        return colSelector.inferType(cols, params, implicitTable)
10✔
4145
}
4146

4147
func (sel *AggColSelector) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
8✔
4148
        if sel.aggFn == COUNT {
10✔
4149
                if t != IntegerType {
3✔
4150
                        return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4151
                }
1✔
4152
                return nil
1✔
4153
        }
4154

4155
        colSelector := &ColSelector{table: sel.table, col: sel.col}
6✔
4156

6✔
4157
        if sel.aggFn == SUM || sel.aggFn == AVG {
10✔
4158
                if t != IntegerType && t != Float64Type {
5✔
4159
                        return fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, t)
1✔
4160
                }
1✔
4161
        }
4162

4163
        return colSelector.requiresType(t, cols, params, implicitTable)
5✔
4164
}
4165

4166
func (sel *AggColSelector) substitute(params map[string]interface{}) (ValueExp, error) {
412✔
4167
        return sel, nil
412✔
4168
}
412✔
4169

4170
func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
459✔
4171
        if row == nil {
460✔
4172
                return nil, fmt.Errorf("%w: no row to evaluate aggregation (%s) in current context", ErrInvalidValue, sel.aggFn)
1✔
4173
        }
1✔
4174

4175
        v, ok := row.ValuesBySelector[EncodeSelector(sel.resolve(implicitTable))]
458✔
4176
        if !ok {
459✔
4177
                return nil, fmt.Errorf("%w (%s)", ErrColumnDoesNotExist, sel.col)
1✔
4178
        }
1✔
4179
        return v, nil
457✔
4180
}
4181

4182
func (sel *AggColSelector) selectors() []Selector {
232✔
4183
        return []Selector{sel}
232✔
4184
}
232✔
4185

4186
func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4187
        return sel
×
4188
}
×
4189

4190
func (sel *AggColSelector) isConstant() bool {
1✔
4191
        return false
1✔
4192
}
1✔
4193

4194
func (sel *AggColSelector) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
×
4195
        return nil
×
4196
}
×
4197

4198
func (sel *AggColSelector) String() string {
×
4199
        return sel.aggFn + "(" + sel.col + ")"
×
4200
}
×
4201

4202
type NumExp struct {
4203
        op          NumOperator
4204
        left, right ValueExp
4205
}
4206

4207
func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
17✔
4208
        // First step - check if we can infer the type of sub-expressions
17✔
4209
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
17✔
4210
        if err != nil {
17✔
4211
                return AnyType, err
×
4212
        }
×
4213
        if tleft != AnyType && tleft != IntegerType && tleft != Float64Type && tleft != JSONType {
17✔
4214
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft)
×
4215
        }
×
4216

4217
        tright, err := bexp.right.inferType(cols, params, implicitTable)
17✔
4218
        if err != nil {
17✔
4219
                return AnyType, err
×
4220
        }
×
4221
        if tright != AnyType && tright != IntegerType && tright != Float64Type && tright != JSONType {
19✔
4222
                return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright)
2✔
4223
        }
2✔
4224

4225
        if tleft == IntegerType && tright == IntegerType {
19✔
4226
                // Both sides are integer types - the result is also integer
4✔
4227
                return IntegerType, nil
4✔
4228
        }
4✔
4229

4230
        if tleft != AnyType && tright != AnyType {
20✔
4231
                // Both sides have concrete types but at least one of them is float
9✔
4232
                return Float64Type, nil
9✔
4233
        }
9✔
4234

4235
        // Both sides are ambiguous
4236
        return AnyType, nil
2✔
4237
}
4238

4239
func copyParams(params map[string]SQLValueType) map[string]SQLValueType {
11✔
4240
        ret := make(map[string]SQLValueType, len(params))
11✔
4241
        for k, v := range params {
15✔
4242
                ret[k] = v
4✔
4243
        }
4✔
4244
        return ret
11✔
4245
}
4246

4247
func restoreParams(params, restore map[string]SQLValueType) {
2✔
4248
        for k := range params {
2✔
4249
                delete(params, k)
×
4250
        }
×
4251
        for k, v := range restore {
2✔
4252
                params[k] = v
×
4253
        }
×
4254
}
4255

4256
func (bexp *NumExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4257
        if t != IntegerType && t != Float64Type {
8✔
4258
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
1✔
4259
        }
1✔
4260

4261
        floatArgs := 2
6✔
4262
        paramsOrig := copyParams(params)
6✔
4263
        err := bexp.left.requiresType(t, cols, params, implicitTable)
6✔
4264
        if err != nil && t == Float64Type {
7✔
4265
                restoreParams(params, paramsOrig)
1✔
4266
                floatArgs--
1✔
4267
                err = bexp.left.requiresType(IntegerType, cols, params, implicitTable)
1✔
4268
        }
1✔
4269
        if err != nil {
7✔
4270
                return err
1✔
4271
        }
1✔
4272

4273
        paramsOrig = copyParams(params)
5✔
4274
        err = bexp.right.requiresType(t, cols, params, implicitTable)
5✔
4275
        if err != nil && t == Float64Type {
6✔
4276
                restoreParams(params, paramsOrig)
1✔
4277
                floatArgs--
1✔
4278
                err = bexp.right.requiresType(IntegerType, cols, params, implicitTable)
1✔
4279
        }
1✔
4280
        if err != nil {
7✔
4281
                return err
2✔
4282
        }
2✔
4283

4284
        if t == Float64Type && floatArgs == 0 {
3✔
4285
                // Currently this case requires explicit float cast
×
4286
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
×
4287
        }
×
4288

4289
        return nil
3✔
4290
}
4291

4292
func (bexp *NumExp) substitute(params map[string]interface{}) (ValueExp, error) {
187✔
4293
        rlexp, err := bexp.left.substitute(params)
187✔
4294
        if err != nil {
187✔
4295
                return nil, err
×
4296
        }
×
4297

4298
        rrexp, err := bexp.right.substitute(params)
187✔
4299
        if err != nil {
187✔
4300
                return nil, err
×
4301
        }
×
4302

4303
        bexp.left = rlexp
187✔
4304
        bexp.right = rrexp
187✔
4305

187✔
4306
        return bexp, nil
187✔
4307
}
4308

4309
func (bexp *NumExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
124,244✔
4310
        vl, err := bexp.left.reduce(tx, row, implicitTable)
124,244✔
4311
        if err != nil {
124,244✔
4312
                return nil, err
×
4313
        }
×
4314

4315
        vr, err := bexp.right.reduce(tx, row, implicitTable)
124,244✔
4316
        if err != nil {
124,244✔
4317
                return nil, err
×
4318
        }
×
4319

4320
        vl = unwrapJSON(vl)
124,244✔
4321
        vr = unwrapJSON(vr)
124,244✔
4322

124,244✔
4323
        return applyNumOperator(bexp.op, vl, vr)
124,244✔
4324
}
4325

4326
func unwrapJSON(v TypedValue) TypedValue {
248,488✔
4327
        if jsonVal, ok := v.(*JSON); ok {
248,588✔
4328
                if sv, isSimple := jsonVal.castToTypedValue(); isSimple {
200✔
4329
                        return sv
100✔
4330
                }
100✔
4331
        }
4332
        return v
248,388✔
4333
}
4334

4335
func (bexp *NumExp) selectors() []Selector {
13✔
4336
        return append(bexp.left.selectors(), bexp.right.selectors()...)
13✔
4337
}
13✔
4338

4339
func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
4340
        return &NumExp{
1✔
4341
                op:    bexp.op,
1✔
4342
                left:  bexp.left.reduceSelectors(row, implicitTable),
1✔
4343
                right: bexp.right.reduceSelectors(row, implicitTable),
1✔
4344
        }
1✔
4345
}
1✔
4346

4347
func (bexp *NumExp) isConstant() bool {
5✔
4348
        return bexp.left.isConstant() && bexp.right.isConstant()
5✔
4349
}
5✔
4350

4351
func (bexp *NumExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
5✔
4352
        return nil
5✔
4353
}
5✔
4354

4355
func (bexp *NumExp) String() string {
18✔
4356
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), NumOperatorString(bexp.op), bexp.right.String())
18✔
4357
}
18✔
4358

4359
// NotBoolExp is the logical negation (NOT) of a boolean expression.
type NotBoolExp struct {
	// exp is the negated sub-expression; reduce rejects it with
	// ErrInvalidCondition unless it evaluates to a bool.
	exp ValueExp
}
4362

4363
func (bexp *NotBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
2✔
4364
        err := bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
2✔
4365
        if err != nil {
2✔
4366
                return AnyType, err
×
4367
        }
×
4368

4369
        return BooleanType, nil
2✔
4370
}
4371

4372
func (bexp *NotBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
6✔
4373
        if t != BooleanType {
7✔
4374
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4375
        }
1✔
4376

4377
        return bexp.exp.requiresType(BooleanType, cols, params, implicitTable)
5✔
4378
}
4379

4380
func (bexp *NotBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
22✔
4381
        rexp, err := bexp.exp.substitute(params)
22✔
4382
        if err != nil {
22✔
4383
                return nil, err
×
4384
        }
×
4385

4386
        bexp.exp = rexp
22✔
4387

22✔
4388
        return bexp, nil
22✔
4389
}
4390

4391
func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
22✔
4392
        v, err := bexp.exp.reduce(tx, row, implicitTable)
22✔
4393
        if err != nil {
22✔
4394
                return nil, err
×
4395
        }
×
4396

4397
        r, isBool := v.RawValue().(bool)
22✔
4398
        if !isBool {
22✔
4399
                return nil, ErrInvalidCondition
×
4400
        }
×
4401

4402
        return &Bool{val: !r}, nil
22✔
4403
}
4404

4405
func (bexp *NotBoolExp) selectors() []Selector {
×
4406
        return bexp.exp.selectors()
×
4407
}
×
4408

4409
func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
×
4410
        return &NotBoolExp{
×
4411
                exp: bexp.exp.reduceSelectors(row, implicitTable),
×
4412
        }
×
4413
}
×
4414

4415
func (bexp *NotBoolExp) isConstant() bool {
1✔
4416
        return bexp.exp.isConstant()
1✔
4417
}
1✔
4418

4419
func (bexp *NotBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
7✔
4420
        return nil
7✔
4421
}
7✔
4422

4423
func (bexp *NotBoolExp) String() string {
12✔
4424
        return fmt.Sprintf("(NOT %s)", bexp.exp.String())
12✔
4425
}
12✔
4426

4427
// LikeBoolExp is a "val [NOT] LIKE pattern" predicate. In reduce the pattern
// value is handed to regexp.MatchString, so at this point it is interpreted
// as a Go regular expression.
// NOTE(review): confirm whether SQL '%'/'_' wildcards are translated anywhere
// before evaluation.
type LikeBoolExp struct {
	// val is the expression being matched; reduce requires it to be a string or NULL.
	val ValueExp
	// notLike selects NOT LIKE, which inverts the match result.
	notLike bool
	// pattern is the expression yielding the regular expression; it must be VARCHAR.
	pattern ValueExp
}
4432

4433
func NewLikeBoolExp(val ValueExp, notLike bool, pattern ValueExp) *LikeBoolExp {
4✔
4434
        return &LikeBoolExp{
4✔
4435
                val:     val,
4✔
4436
                notLike: notLike,
4✔
4437
                pattern: pattern,
4✔
4438
        }
4✔
4439
}
4✔
4440

4441
func (bexp *LikeBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
3✔
4442
        if bexp.val == nil || bexp.pattern == nil {
4✔
4443
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4444
        }
1✔
4445

4446
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
2✔
4447
        if err != nil {
3✔
4448
                return AnyType, fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4449
        }
1✔
4450

4451
        return BooleanType, nil
1✔
4452
}
4453

4454
func (bexp *LikeBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
7✔
4455
        if bexp.val == nil || bexp.pattern == nil {
9✔
4456
                return fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
2✔
4457
        }
2✔
4458

4459
        if t != BooleanType {
7✔
4460
                return fmt.Errorf("error using the value of the LIKE operator as %s: %w", t, ErrInvalidTypes)
2✔
4461
        }
2✔
4462

4463
        err := bexp.pattern.requiresType(VarcharType, cols, params, implicitTable)
3✔
4464
        if err != nil {
4✔
4465
                return fmt.Errorf("error in 'LIKE' clause: %w", err)
1✔
4466
        }
1✔
4467

4468
        return nil
2✔
4469
}
4470

4471
func (bexp *LikeBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
135✔
4472
        if bexp.val == nil || bexp.pattern == nil {
136✔
4473
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4474
        }
1✔
4475

4476
        val, err := bexp.val.substitute(params)
134✔
4477
        if err != nil {
134✔
4478
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4479
        }
×
4480

4481
        pattern, err := bexp.pattern.substitute(params)
134✔
4482
        if err != nil {
134✔
4483
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4484
        }
×
4485

4486
        return &LikeBoolExp{
134✔
4487
                val:     val,
134✔
4488
                notLike: bexp.notLike,
134✔
4489
                pattern: pattern,
134✔
4490
        }, nil
134✔
4491
}
4492

4493
func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
143✔
4494
        if bexp.val == nil || bexp.pattern == nil {
144✔
4495
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", ErrInvalidCondition)
1✔
4496
        }
1✔
4497

4498
        rval, err := bexp.val.reduce(tx, row, implicitTable)
142✔
4499
        if err != nil {
142✔
4500
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4501
        }
×
4502

4503
        if rval.IsNull() {
143✔
4504
                return &Bool{val: bexp.notLike}, nil
1✔
4505
        }
1✔
4506

4507
        rvalStr, ok := rval.RawValue().(string)
141✔
4508
        if !ok {
142✔
4509
                return nil, fmt.Errorf("error in 'LIKE' clause: %w (expecting %s)", ErrInvalidTypes, VarcharType)
1✔
4510
        }
1✔
4511

4512
        rpattern, err := bexp.pattern.reduce(tx, row, implicitTable)
140✔
4513
        if err != nil {
140✔
4514
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4515
        }
×
4516

4517
        if rpattern.Type() != VarcharType {
140✔
4518
                return nil, fmt.Errorf("error evaluating 'LIKE' clause: %w", ErrInvalidTypes)
×
4519
        }
×
4520

4521
        matched, err := regexp.MatchString(rpattern.RawValue().(string), rvalStr)
140✔
4522
        if err != nil {
140✔
4523
                return nil, fmt.Errorf("error in 'LIKE' clause: %w", err)
×
4524
        }
×
4525

4526
        return &Bool{val: matched != bexp.notLike}, nil
140✔
4527
}
4528

4529
func (bexp *LikeBoolExp) selectors() []Selector {
1✔
4530
        return bexp.val.selectors()
1✔
4531
}
1✔
4532

4533
// reduceSelectors returns the predicate itself, unchanged: the selectors of
// its operands are NOT resolved against row.
// NOTE(review): CmpBoolExp, NumExp, NotBoolExp and BinBoolExp all rebuild the
// node with reduced operands. Confirm whether a LIKE whose operands reference
// columns provided by row is expected to be resolved here.
func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
	return bexp
}
1✔
4536

4537
func (bexp *LikeBoolExp) isConstant() bool {
2✔
4538
        return false
2✔
4539
}
2✔
4540

4541
func (bexp *LikeBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
8✔
4542
        return nil
8✔
4543
}
8✔
4544

4545
func (bexp *LikeBoolExp) String() string {
5✔
4546
        fmtStr := "(%s LIKE %s)"
5✔
4547
        if bexp.notLike {
6✔
4548
                fmtStr = "(%s NOT LIKE %s)"
1✔
4549
        }
1✔
4550
        return fmt.Sprintf(fmtStr, bexp.val.String(), bexp.pattern.String())
5✔
4551
}
4552

4553
// CmpBoolExp is a binary comparison "left op right" that evaluates to a
// boolean. The operators handled by cmpSatisfiesOp are EQ, NE, LT, LE, GT
// and GE.
type CmpBoolExp struct {
	op          CmpOperator
	left, right ValueExp
}
4557

4558
func NewCmpBoolExp(op CmpOperator, left, right ValueExp) *CmpBoolExp {
67✔
4559
        return &CmpBoolExp{
67✔
4560
                op:    op,
67✔
4561
                left:  left,
67✔
4562
                right: right,
67✔
4563
        }
67✔
4564
}
67✔
4565

4566
func (bexp *CmpBoolExp) Left() ValueExp {
×
4567
        return bexp.left
×
4568
}
×
4569

4570
func (bexp *CmpBoolExp) Right() ValueExp {
×
4571
        return bexp.right
×
4572
}
×
4573

4574
func (bexp *CmpBoolExp) OP() CmpOperator {
×
4575
        return bexp.op
×
4576
}
×
4577

4578
func (bexp *CmpBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
64✔
4579
        tleft, err := bexp.left.inferType(cols, params, implicitTable)
64✔
4580
        if err != nil {
64✔
4581
                return AnyType, err
×
4582
        }
×
4583

4584
        tright, err := bexp.right.inferType(cols, params, implicitTable)
64✔
4585
        if err != nil {
66✔
4586
                return AnyType, err
2✔
4587
        }
2✔
4588

4589
        // unification step
4590

4591
        if tleft == tright {
75✔
4592
                return BooleanType, nil
13✔
4593
        }
13✔
4594

4595
        _, ok := coerceTypes(tleft, tright)
49✔
4596
        if !ok {
53✔
4597
                return AnyType, fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, tleft, tright)
4✔
4598
        }
4✔
4599

4600
        if tleft == AnyType {
49✔
4601
                err = bexp.left.requiresType(tright, cols, params, implicitTable)
4✔
4602
                if err != nil {
4✔
4603
                        return AnyType, err
×
4604
                }
×
4605
        }
4606

4607
        if tright == AnyType {
85✔
4608
                err = bexp.right.requiresType(tleft, cols, params, implicitTable)
40✔
4609
                if err != nil {
41✔
4610
                        return AnyType, err
1✔
4611
                }
1✔
4612
        }
4613
        return BooleanType, nil
44✔
4614
}
4615

4616
func coerceTypes(t1, t2 SQLValueType) (SQLValueType, bool) {
49✔
4617
        switch {
49✔
4618
        case t1 == t2:
×
4619
                return t1, true
×
4620
        case t1 == AnyType:
4✔
4621
                return t2, true
4✔
4622
        case t2 == AnyType:
40✔
4623
                return t1, true
40✔
4624
        case (t1 == IntegerType && t2 == Float64Type) ||
4625
                (t1 == Float64Type && t2 == IntegerType):
1✔
4626
                return Float64Type, true
1✔
4627
        }
4628
        return "", false
4✔
4629
}
4630

4631
func (bexp *CmpBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
41✔
4632
        if t != BooleanType {
42✔
4633
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
1✔
4634
        }
1✔
4635

4636
        _, err := bexp.inferType(cols, params, implicitTable)
40✔
4637
        return err
40✔
4638
}
4639

4640
func (bexp *CmpBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
14,312✔
4641
        rlexp, err := bexp.left.substitute(params)
14,312✔
4642
        if err != nil {
14,312✔
4643
                return nil, err
×
4644
        }
×
4645

4646
        rrexp, err := bexp.right.substitute(params)
14,312✔
4647
        if err != nil {
14,313✔
4648
                return nil, err
1✔
4649
        }
1✔
4650

4651
        bexp.left = rlexp
14,311✔
4652
        bexp.right = rrexp
14,311✔
4653

14,311✔
4654
        return bexp, nil
14,311✔
4655
}
4656

4657
func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
13,911✔
4658
        vl, err := bexp.left.reduce(tx, row, implicitTable)
13,911✔
4659
        if err != nil {
13,914✔
4660
                return nil, err
3✔
4661
        }
3✔
4662

4663
        vr, err := bexp.right.reduce(tx, row, implicitTable)
13,908✔
4664
        if err != nil {
13,910✔
4665
                return nil, err
2✔
4666
        }
2✔
4667

4668
        r, err := vl.Compare(vr)
13,906✔
4669
        if err != nil {
13,910✔
4670
                return nil, err
4✔
4671
        }
4✔
4672

4673
        return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil
13,902✔
4674
}
4675

4676
func (bexp *CmpBoolExp) selectors() []Selector {
12✔
4677
        return append(bexp.left.selectors(), bexp.right.selectors()...)
12✔
4678
}
12✔
4679

4680
func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
282✔
4681
        return &CmpBoolExp{
282✔
4682
                op:    bexp.op,
282✔
4683
                left:  bexp.left.reduceSelectors(row, implicitTable),
282✔
4684
                right: bexp.right.reduceSelectors(row, implicitTable),
282✔
4685
        }
282✔
4686
}
282✔
4687

4688
func (bexp *CmpBoolExp) isConstant() bool {
2✔
4689
        return bexp.left.isConstant() && bexp.right.isConstant()
2✔
4690
}
2✔
4691

4692
func (bexp *CmpBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
604✔
4693
        matchingFunc := func(_, right ValueExp) (*ColSelector, ValueExp, bool) {
1,418✔
4694
                s, isSel := bexp.left.(*ColSelector)
814✔
4695
                if isSel && s.col != revCol && bexp.right.isConstant() {
1,208✔
4696
                        return s, right, true
394✔
4697
                }
394✔
4698
                return nil, nil, false
420✔
4699
        }
4700

4701
        sel, c, ok := matchingFunc(bexp.left, bexp.right)
604✔
4702
        if !ok {
814✔
4703
                sel, c, ok = matchingFunc(bexp.right, bexp.left)
210✔
4704
        }
210✔
4705

4706
        if !ok {
814✔
4707
                return nil
210✔
4708
        }
210✔
4709

4710
        aggFn, t, col := sel.resolve(table.name)
394✔
4711
        if aggFn != "" || t != asTable {
408✔
4712
                return nil
14✔
4713
        }
14✔
4714

4715
        column, err := table.GetColumnByName(col)
380✔
4716
        if err != nil {
381✔
4717
                return err
1✔
4718
        }
1✔
4719

4720
        val, err := c.substitute(params)
379✔
4721
        if errors.Is(err, ErrMissingParameter) {
438✔
4722
                // TODO: not supported when parameters are not provided during query resolution
59✔
4723
                return nil
59✔
4724
        }
59✔
4725
        if err != nil {
320✔
4726
                return err
×
4727
        }
×
4728

4729
        rval, err := val.reduce(nil, nil, table.name)
320✔
4730
        if err != nil {
321✔
4731
                return err
1✔
4732
        }
1✔
4733

4734
        return updateRangeFor(column.id, rval, bexp.op, rangesByColID)
319✔
4735
}
4736

4737
func (bexp *CmpBoolExp) String() string {
20✔
4738
        opStr := CmpOperatorToString(bexp.op)
20✔
4739
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), opStr, bexp.right.String())
20✔
4740
}
20✔
4741

4742
func updateRangeFor(colID uint32, val TypedValue, cmp CmpOperator, rangesByColID map[uint32]*typedValueRange) error {
319✔
4743
        currRange, ranged := rangesByColID[colID]
319✔
4744
        var newRange *typedValueRange
319✔
4745

319✔
4746
        switch cmp {
319✔
4747
        case EQ:
250✔
4748
                {
500✔
4749
                        newRange = &typedValueRange{
250✔
4750
                                lRange: &typedValueSemiRange{
250✔
4751
                                        val:       val,
250✔
4752
                                        inclusive: true,
250✔
4753
                                },
250✔
4754
                                hRange: &typedValueSemiRange{
250✔
4755
                                        val:       val,
250✔
4756
                                        inclusive: true,
250✔
4757
                                },
250✔
4758
                        }
250✔
4759
                }
250✔
4760
        case LT:
13✔
4761
                {
26✔
4762
                        newRange = &typedValueRange{
13✔
4763
                                hRange: &typedValueSemiRange{
13✔
4764
                                        val: val,
13✔
4765
                                },
13✔
4766
                        }
13✔
4767
                }
13✔
4768
        case LE:
10✔
4769
                {
20✔
4770
                        newRange = &typedValueRange{
10✔
4771
                                hRange: &typedValueSemiRange{
10✔
4772
                                        val:       val,
10✔
4773
                                        inclusive: true,
10✔
4774
                                },
10✔
4775
                        }
10✔
4776
                }
10✔
4777
        case GT:
18✔
4778
                {
36✔
4779
                        newRange = &typedValueRange{
18✔
4780
                                lRange: &typedValueSemiRange{
18✔
4781
                                        val: val,
18✔
4782
                                },
18✔
4783
                        }
18✔
4784
                }
18✔
4785
        case GE:
16✔
4786
                {
32✔
4787
                        newRange = &typedValueRange{
16✔
4788
                                lRange: &typedValueSemiRange{
16✔
4789
                                        val:       val,
16✔
4790
                                        inclusive: true,
16✔
4791
                                },
16✔
4792
                        }
16✔
4793
                }
16✔
4794
        case NE:
12✔
4795
                {
24✔
4796
                        return nil
12✔
4797
                }
12✔
4798
        }
4799

4800
        if !ranged {
611✔
4801
                rangesByColID[colID] = newRange
304✔
4802
                return nil
304✔
4803
        }
304✔
4804

4805
        return currRange.refineWith(newRange)
3✔
4806
}
4807

4808
func cmpSatisfiesOp(cmp int, op CmpOperator) bool {
13,902✔
4809
        switch {
13,902✔
4810
        case cmp == 0:
1,161✔
4811
                {
2,322✔
4812
                        return op == EQ || op == LE || op == GE
1,161✔
4813
                }
1,161✔
4814
        case cmp < 0:
6,486✔
4815
                {
12,972✔
4816
                        return op == NE || op == LT || op == LE
6,486✔
4817
                }
6,486✔
4818
        case cmp > 0:
6,255✔
4819
                {
12,510✔
4820
                        return op == NE || op == GT || op == GE
6,255✔
4821
                }
6,255✔
4822
        }
4823
        return false
×
4824
}
4825

4826
// BinBoolExp combines two boolean expressions with a logical operator. The
// operators evaluated by reduce are And and Or.
type BinBoolExp struct {
	op          LogicOperator
	left, right ValueExp
}
4830

4831
func NewBinBoolExp(op LogicOperator, lrexp, rrexp ValueExp) *BinBoolExp {
18✔
4832
        bexp := &BinBoolExp{
18✔
4833
                op: op,
18✔
4834
        }
18✔
4835

18✔
4836
        bexp.left = lrexp
18✔
4837
        bexp.right = rrexp
18✔
4838

18✔
4839
        return bexp
18✔
4840
}
18✔
4841

4842
func (bexp *BinBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
19✔
4843
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
4844
        if err != nil {
19✔
4845
                return AnyType, err
×
4846
        }
×
4847

4848
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
19✔
4849
        if err != nil {
21✔
4850
                return AnyType, err
2✔
4851
        }
2✔
4852

4853
        return BooleanType, nil
17✔
4854
}
4855

4856
func (bexp *BinBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
22✔
4857
        if t != BooleanType {
25✔
4858
                return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
3✔
4859
        }
3✔
4860

4861
        err := bexp.left.requiresType(BooleanType, cols, params, implicitTable)
19✔
4862
        if err != nil {
20✔
4863
                return err
1✔
4864
        }
1✔
4865

4866
        err = bexp.right.requiresType(BooleanType, cols, params, implicitTable)
18✔
4867
        if err != nil {
18✔
4868
                return err
×
4869
        }
×
4870

4871
        return nil
18✔
4872
}
4873

4874
func (bexp *BinBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
570✔
4875
        rlexp, err := bexp.left.substitute(params)
570✔
4876
        if err != nil {
570✔
4877
                return nil, err
×
4878
        }
×
4879

4880
        rrexp, err := bexp.right.substitute(params)
570✔
4881
        if err != nil {
570✔
4882
                return nil, err
×
4883
        }
×
4884

4885
        bexp.left = rlexp
570✔
4886
        bexp.right = rrexp
570✔
4887

570✔
4888
        return bexp, nil
570✔
4889
}
4890

4891
func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
535✔
4892
        vl, err := bexp.left.reduce(tx, row, implicitTable)
535✔
4893
        if err != nil {
536✔
4894
                return nil, err
1✔
4895
        }
1✔
4896

4897
        bl, isBool := vl.(*Bool)
534✔
4898
        if !isBool {
534✔
4899
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
4900
        }
×
4901

4902
        // short-circuit evaluation
4903
        if (bl.val && bexp.op == Or) || (!bl.val && bexp.op == And) {
710✔
4904
                return &Bool{val: bl.val}, nil
176✔
4905
        }
176✔
4906

4907
        vr, err := bexp.right.reduce(tx, row, implicitTable)
358✔
4908
        if err != nil {
359✔
4909
                return nil, err
1✔
4910
        }
1✔
4911

4912
        br, isBool := vr.(*Bool)
357✔
4913
        if !isBool {
357✔
4914
                return nil, fmt.Errorf("%w (expecting boolean value)", ErrInvalidValue)
×
4915
        }
×
4916

4917
        switch bexp.op {
357✔
4918
        case And:
334✔
4919
                {
668✔
4920
                        return &Bool{val: bl.val && br.val}, nil
334✔
4921
                }
334✔
4922
        case Or:
23✔
4923
                {
46✔
4924
                        return &Bool{val: bl.val || br.val}, nil
23✔
4925
                }
23✔
4926
        }
4927

4928
        return nil, ErrUnexpected
×
4929
}
4930

4931
func (bexp *BinBoolExp) selectors() []Selector {
2✔
4932
        return append(bexp.left.selectors(), bexp.right.selectors()...)
2✔
4933
}
2✔
4934

4935
func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
15✔
4936
        return &BinBoolExp{
15✔
4937
                op:    bexp.op,
15✔
4938
                left:  bexp.left.reduceSelectors(row, implicitTable),
15✔
4939
                right: bexp.right.reduceSelectors(row, implicitTable),
15✔
4940
        }
15✔
4941
}
15✔
4942

4943
func (bexp *BinBoolExp) isConstant() bool {
1✔
4944
        return bexp.left.isConstant() && bexp.right.isConstant()
1✔
4945
}
1✔
4946

4947
// selectorRanges collects into rangesByColID the column scan ranges implied
// by the logical expression.
//
// For And, both operands write into the same map. A column constrained on
// both sides is therefore combined by the operands themselves; for
// comparisons this happens through updateRangeFor/refineWith.
//
// For any other operator (Or), each operand writes into its own scratch map.
// Only columns constrained on BOTH sides get a range: the left range is
// extended with the right one through extendWith. A column constrained on one
// side only cannot be narrowed, so it is skipped.
func (bexp *BinBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
	if bexp.op == And {
		err := bexp.left.selectorRanges(table, asTable, params, rangesByColID)
		if err != nil {
			return err
		}

		return bexp.right.selectorRanges(table, asTable, params, rangesByColID)
	}

	// Scratch maps keep each disjunct's ranges separate until they are merged.
	lRanges := make(map[uint32]*typedValueRange)
	rRanges := make(map[uint32]*typedValueRange)

	err := bexp.left.selectorRanges(table, asTable, params, lRanges)
	if err != nil {
		return err
	}

	err = bexp.right.selectorRanges(table, asTable, params, rRanges)
	if err != nil {
		return err
	}

	for colID, lr := range lRanges {
		rr, ok := rRanges[colID]
		if !ok {
			// Unconstrained on the right side: the OR cannot narrow this column.
			continue
		}

		// lr is mutated in place; it belongs to the local scratch map only.
		err = lr.extendWith(rr)
		if err != nil {
			return err
		}

		// NOTE(review): this replaces, rather than refines, any range already
		// stored for colID (e.g. by an enclosing And). Confirm this is intended.
		rangesByColID[colID] = lr
	}

	return nil
}
4986

4987
func (bexp *BinBoolExp) String() string {
31✔
4988
        return fmt.Sprintf("(%s %s %s)", bexp.left.String(), LogicOperatorToString(bexp.op), bexp.right.String())
31✔
4989
}
31✔
4990

4991
type ExistsBoolExp struct {
4992
        q DataSource
4993
}
4994

4995
func (bexp *ExistsBoolExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
4996
        return AnyType, fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
4997
}
1✔
4998

4999
func (bexp *ExistsBoolExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5000
        return fmt.Errorf("error inferring type in 'EXISTS' clause: %w", ErrNoSupported)
1✔
5001
}
1✔
5002

5003
func (bexp *ExistsBoolExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5004
        return bexp, nil
1✔
5005
}
1✔
5006

5007
func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5008
        return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported)
2✔
5009
}
2✔
5010

5011
func (bexp *ExistsBoolExp) selectors() []Selector {
1✔
5012
        return nil
1✔
5013
}
1✔
5014

5015
func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5016
        return bexp
1✔
5017
}
1✔
5018

5019
func (bexp *ExistsBoolExp) isConstant() bool {
2✔
5020
        return false
2✔
5021
}
2✔
5022

5023
func (bexp *ExistsBoolExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5024
        return nil
1✔
5025
}
1✔
5026

5027
func (bexp *ExistsBoolExp) String() string {
×
5028
        return ""
×
5029
}
×
5030

5031
type InSubQueryExp struct {
5032
        val   ValueExp
5033
        notIn bool
5034
        q     *SelectStmt
5035
}
5036

5037
func (bexp *InSubQueryExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
1✔
5038
        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5039
}
1✔
5040

5041
func (bexp *InSubQueryExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
1✔
5042
        return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
1✔
5043
}
1✔
5044

5045
func (bexp *InSubQueryExp) substitute(params map[string]interface{}) (ValueExp, error) {
1✔
5046
        return bexp, nil
1✔
5047
}
1✔
5048

5049
func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
2✔
5050
        return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported)
2✔
5051
}
2✔
5052

5053
func (bexp *InSubQueryExp) selectors() []Selector {
1✔
5054
        return bexp.val.selectors()
1✔
5055
}
1✔
5056

5057
func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
1✔
5058
        return bexp
1✔
5059
}
1✔
5060

5061
func (bexp *InSubQueryExp) isConstant() bool {
1✔
5062
        return false
1✔
5063
}
1✔
5064

5065
func (bexp *InSubQueryExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
1✔
5066
        return nil
1✔
5067
}
1✔
5068

5069
func (bexp *InSubQueryExp) String() string {
×
5070
        return ""
×
5071
}
×
5072

5073
// TODO: once InSubQueryExp is supported, this struct may become obsolete by creating a ListDataSource struct
5074
type InListExp struct {
5075
        val    ValueExp
5076
        notIn  bool
5077
        values []ValueExp
5078
}
5079

5080
func (bexp *InListExp) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
6✔
5081
        t, err := bexp.val.inferType(cols, params, implicitTable)
6✔
5082
        if err != nil {
7✔
5083
                return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
5084
        }
1✔
5085

5086
        for _, v := range bexp.values {
15✔
5087
                err = v.requiresType(t, cols, params, implicitTable)
10✔
5088
                if err != nil {
11✔
5089
                        return AnyType, fmt.Errorf("error inferring type in 'IN' clause: %w", err)
1✔
5090
                }
1✔
5091
        }
5092

5093
        return BooleanType, nil
4✔
5094
}
5095

5096
func (bexp *InListExp) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
2✔
5097
        _, err := bexp.inferType(cols, params, implicitTable)
2✔
5098
        if err != nil {
3✔
5099
                return err
1✔
5100
        }
1✔
5101

5102
        if t != BooleanType {
1✔
5103
                return fmt.Errorf("error inferring type in 'IN' clause: %w", ErrInvalidTypes)
×
5104
        }
×
5105

5106
        return nil
1✔
5107
}
5108

5109
func (bexp *InListExp) substitute(params map[string]interface{}) (ValueExp, error) {
115✔
5110
        val, err := bexp.val.substitute(params)
115✔
5111
        if err != nil {
115✔
5112
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5113
        }
×
5114

5115
        values := make([]ValueExp, len(bexp.values))
115✔
5116

115✔
5117
        for i, val := range bexp.values {
245✔
5118
                values[i], err = val.substitute(params)
130✔
5119
                if err != nil {
130✔
5120
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
×
5121
                }
×
5122
        }
5123

5124
        return &InListExp{
115✔
5125
                val:    val,
115✔
5126
                notIn:  bexp.notIn,
115✔
5127
                values: values,
115✔
5128
        }, nil
115✔
5129
}
5130

5131
func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) {
115✔
5132
        rval, err := bexp.val.reduce(tx, row, implicitTable)
115✔
5133
        if err != nil {
116✔
5134
                return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5135
        }
1✔
5136

5137
        var found bool
114✔
5138

114✔
5139
        for _, v := range bexp.values {
241✔
5140
                rv, err := v.reduce(tx, row, implicitTable)
127✔
5141
                if err != nil {
128✔
5142
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5143
                }
1✔
5144

5145
                r, err := rval.Compare(rv)
126✔
5146
                if err != nil {
127✔
5147
                        return nil, fmt.Errorf("error evaluating 'IN' clause: %w", err)
1✔
5148
                }
1✔
5149

5150
                if r == 0 {
140✔
5151
                        // TODO: short-circuit evaluation may be preferred when upfront static type inference is in place
15✔
5152
                        found = found || true
15✔
5153
                }
15✔
5154
        }
5155

5156
        return &Bool{val: found != bexp.notIn}, nil
112✔
5157
}
5158

5159
func (bexp *InListExp) selectors() []Selector {
1✔
5160
        selectors := make([]Selector, 0, len(bexp.values))
1✔
5161
        for _, v := range bexp.values {
4✔
5162
                selectors = append(selectors, v.selectors()...)
3✔
5163
        }
3✔
5164
        return append(bexp.val.selectors(), selectors...)
1✔
5165
}
5166

5167
func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp {
10✔
5168
        values := make([]ValueExp, len(bexp.values))
10✔
5169

10✔
5170
        for i, val := range bexp.values {
20✔
5171
                values[i] = val.reduceSelectors(row, implicitTable)
10✔
5172
        }
10✔
5173

5174
        return &InListExp{
10✔
5175
                val:    bexp.val.reduceSelectors(row, implicitTable),
10✔
5176
                values: values,
10✔
5177
        }
10✔
5178
}
5179

5180
func (bexp *InListExp) isConstant() bool {
1✔
5181
        return false
1✔
5182
}
1✔
5183

5184
func (bexp *InListExp) selectorRanges(table *Table, asTable string, params map[string]interface{}, rangesByColID map[uint32]*typedValueRange) error {
20✔
5185
        // TODO: may be determiined by smallest and bigggest value in the list
20✔
5186
        return nil
20✔
5187
}
20✔
5188

5189
func (bexp *InListExp) String() string {
1✔
5190
        values := make([]string, len(bexp.values))
1✔
5191
        for i, exp := range bexp.values {
5✔
5192
                values[i] = exp.String()
4✔
5193
        }
4✔
5194
        return fmt.Sprintf("%s IN (%s)", bexp.val.String(), strings.Join(values, ","))
1✔
5195
}
5196

5197
type FnDataSourceStmt struct {
5198
        fnCall *FnCall
5199
        as     string
5200
}
5201

5202
func (stmt *FnDataSourceStmt) readOnly() bool {
1✔
5203
        return true
1✔
5204
}
1✔
5205

5206
func (stmt *FnDataSourceStmt) requiredPrivileges() []SQLPrivilege {
1✔
5207
        return nil
1✔
5208
}
1✔
5209

5210
func (stmt *FnDataSourceStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
×
5211
        return tx, nil
×
5212
}
×
5213

5214
func (stmt *FnDataSourceStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5215
        return nil
1✔
5216
}
1✔
5217

5218
func (stmt *FnDataSourceStmt) Alias() string {
24✔
5219
        if stmt.as != "" {
26✔
5220
                return stmt.as
2✔
5221
        }
2✔
5222

5223
        switch strings.ToUpper(stmt.fnCall.fn) {
22✔
5224
        case DatabasesFnCall:
3✔
5225
                {
6✔
5226
                        return "databases"
3✔
5227
                }
3✔
5228
        case TablesFnCall:
5✔
5229
                {
10✔
5230
                        return "tables"
5✔
5231
                }
5✔
5232
        case TableFnCall:
×
5233
                {
×
5234
                        return "table"
×
5235
                }
×
5236
        case UsersFnCall:
7✔
5237
                {
14✔
5238
                        return "users"
7✔
5239
                }
7✔
5240
        case ColumnsFnCall:
3✔
5241
                {
6✔
5242
                        return "columns"
3✔
5243
                }
3✔
5244
        case IndexesFnCall:
2✔
5245
                {
4✔
5246
                        return "indexes"
2✔
5247
                }
2✔
5248
        case GrantsFnCall:
2✔
5249
                return "grants"
2✔
5250
        }
5251

5252
        // not reachable
5253
        return ""
×
5254
}
5255

5256
func (stmt *FnDataSourceStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (rowReader RowReader, err error) {
25✔
5257
        if stmt.fnCall == nil {
25✔
5258
                return nil, fmt.Errorf("%w: function is unspecified", ErrIllegalArguments)
×
5259
        }
×
5260

5261
        switch strings.ToUpper(stmt.fnCall.fn) {
25✔
5262
        case DatabasesFnCall:
5✔
5263
                {
10✔
5264
                        return stmt.resolveListDatabases(ctx, tx, params, scanSpecs)
5✔
5265
                }
5✔
5266
        case TablesFnCall:
5✔
5267
                {
10✔
5268
                        return stmt.resolveListTables(ctx, tx, params, scanSpecs)
5✔
5269
                }
5✔
5270
        case TableFnCall:
×
5271
                {
×
5272
                        return stmt.resolveShowTable(ctx, tx, params, scanSpecs)
×
5273
                }
×
5274
        case UsersFnCall:
7✔
5275
                {
14✔
5276
                        return stmt.resolveListUsers(ctx, tx, params, scanSpecs)
7✔
5277
                }
7✔
5278
        case ColumnsFnCall:
3✔
5279
                {
6✔
5280
                        return stmt.resolveListColumns(ctx, tx, params, scanSpecs)
3✔
5281
                }
3✔
5282
        case IndexesFnCall:
3✔
5283
                {
6✔
5284
                        return stmt.resolveListIndexes(ctx, tx, params, scanSpecs)
3✔
5285
                }
3✔
5286
        case GrantsFnCall:
2✔
5287
                {
4✔
5288
                        return stmt.resolveListGrants(ctx, tx, params, scanSpecs)
2✔
5289
                }
2✔
5290
        }
5291

5292
        return nil, fmt.Errorf("%w (%s)", ErrFunctionDoesNotExist, stmt.fnCall.fn)
×
5293
}
5294

5295
func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5296
        if len(stmt.fnCall.params) > 0 {
5✔
5297
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, DatabasesFnCall, len(stmt.fnCall.params))
×
5298
        }
×
5299

5300
        cols := make([]ColDescriptor, 1)
5✔
5301
        cols[0] = ColDescriptor{
5✔
5302
                Column: "name",
5✔
5303
                Type:   VarcharType,
5✔
5304
        }
5✔
5305

5✔
5306
        var dbs []string
5✔
5307

5✔
5308
        if tx.engine.multidbHandler == nil {
6✔
5309
                return nil, ErrUnspecifiedMultiDBHandler
1✔
5310
        } else {
5✔
5311
                dbs, err = tx.engine.multidbHandler.ListDatabases(ctx)
4✔
5312
                if err != nil {
4✔
5313
                        return nil, err
×
5314
                }
×
5315
        }
5316

5317
        values := make([][]ValueExp, len(dbs))
4✔
5318

4✔
5319
        for i, db := range dbs {
12✔
5320
                values[i] = []ValueExp{&Varchar{val: db}}
8✔
5321
        }
8✔
5322

5323
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
4✔
5324
}
5325

5326
func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
5✔
5327
        if len(stmt.fnCall.params) > 0 {
5✔
5328
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, TablesFnCall, len(stmt.fnCall.params))
×
5329
        }
×
5330

5331
        cols := make([]ColDescriptor, 1)
5✔
5332
        cols[0] = ColDescriptor{
5✔
5333
                Column: "name",
5✔
5334
                Type:   VarcharType,
5✔
5335
        }
5✔
5336

5✔
5337
        tables := tx.catalog.GetTables()
5✔
5338

5✔
5339
        values := make([][]ValueExp, len(tables))
5✔
5340

5✔
5341
        for i, t := range tables {
14✔
5342
                values[i] = []ValueExp{&Varchar{val: t.name}}
9✔
5343
        }
9✔
5344

5345
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
5✔
5346
}
5347

5348
func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
×
5349
        cols := []ColDescriptor{
×
5350
                {
×
5351
                        Column: "column_name",
×
5352
                        Type:   VarcharType,
×
5353
                },
×
5354
                {
×
5355
                        Column: "type_name",
×
5356
                        Type:   VarcharType,
×
5357
                },
×
5358
                {
×
5359
                        Column: "is_nullable",
×
5360
                        Type:   BooleanType,
×
5361
                },
×
5362
                {
×
5363
                        Column: "is_indexed",
×
5364
                        Type:   VarcharType,
×
5365
                },
×
5366
                {
×
5367
                        Column: "is_auto_increment",
×
5368
                        Type:   BooleanType,
×
5369
                },
×
5370
                {
×
5371
                        Column: "is_unique",
×
5372
                        Type:   BooleanType,
×
5373
                },
×
5374
        }
×
5375

×
5376
        tableName, _ := stmt.fnCall.params[0].reduce(tx, nil, "")
×
5377
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
×
5378
        if err != nil {
×
5379
                return nil, err
×
5380
        }
×
5381

5382
        values := make([][]ValueExp, len(table.cols))
×
5383

×
5384
        for i, c := range table.cols {
×
5385
                index := "NO"
×
5386

×
5387
                indexed, err := table.IsIndexed(c.Name())
×
5388
                if err != nil {
×
5389
                        return nil, err
×
5390
                }
×
5391
                if indexed {
×
5392
                        index = "YES"
×
5393
                }
×
5394

5395
                if table.PrimaryIndex().IncludesCol(c.ID()) {
×
5396
                        index = "PRIMARY KEY"
×
5397
                }
×
5398

5399
                var unique bool
×
5400
                for _, index := range table.GetIndexesByColID(c.ID()) {
×
5401
                        if index.IsUnique() && len(index.Cols()) == 1 {
×
5402
                                unique = true
×
5403
                                break
×
5404
                        }
5405
                }
5406

5407
                var maxLen string
×
5408

×
5409
                if c.MaxLen() > 0 && (c.Type() == VarcharType || c.Type() == BLOBType) {
×
5410
                        maxLen = fmt.Sprintf("(%d)", c.MaxLen())
×
5411
                }
×
5412

5413
                values[i] = []ValueExp{
×
5414
                        &Varchar{val: c.colName},
×
5415
                        &Varchar{val: c.Type() + maxLen},
×
5416
                        &Bool{val: c.IsNullable()},
×
5417
                        &Varchar{val: index},
×
5418
                        &Bool{val: c.IsAutoIncremental()},
×
5419
                        &Bool{val: unique},
×
5420
                }
×
5421
        }
5422

NEW
5423
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
×
5424
}
5425

5426
func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) {
7✔
5427
        if len(stmt.fnCall.params) > 0 {
7✔
5428
                return nil, fmt.Errorf("%w: function '%s' expect no parameters but %d were provided", ErrIllegalArguments, UsersFnCall, len(stmt.fnCall.params))
×
5429
        }
×
5430

5431
        cols := []ColDescriptor{
7✔
5432
                {
7✔
5433
                        Column: "name",
7✔
5434
                        Type:   VarcharType,
7✔
5435
                },
7✔
5436
                {
7✔
5437
                        Column: "permission",
7✔
5438
                        Type:   VarcharType,
7✔
5439
                },
7✔
5440
        }
7✔
5441

7✔
5442
        users, err := tx.ListUsers(ctx)
7✔
5443
        if err != nil {
7✔
NEW
5444
                return nil, err
×
UNCOV
5445
        }
×
5446

5447
        values := make([][]ValueExp, len(users))
7✔
5448
        for i, user := range users {
23✔
5449
                perm := user.Permission()
16✔
5450

16✔
5451
                values[i] = []ValueExp{
16✔
5452
                        &Varchar{val: user.Username()},
16✔
5453
                        &Varchar{val: perm},
16✔
5454
                }
16✔
5455
        }
16✔
5456
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
7✔
5457
}
5458

5459
func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5460
        if len(stmt.fnCall.params) != 1 {
3✔
5461
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, ColumnsFnCall)
×
5462
        }
×
5463

5464
        cols := []ColDescriptor{
3✔
5465
                {
3✔
5466
                        Column: "table",
3✔
5467
                        Type:   VarcharType,
3✔
5468
                },
3✔
5469
                {
3✔
5470
                        Column: "name",
3✔
5471
                        Type:   VarcharType,
3✔
5472
                },
3✔
5473
                {
3✔
5474
                        Column: "type",
3✔
5475
                        Type:   VarcharType,
3✔
5476
                },
3✔
5477
                {
3✔
5478
                        Column: "max_length",
3✔
5479
                        Type:   IntegerType,
3✔
5480
                },
3✔
5481
                {
3✔
5482
                        Column: "nullable",
3✔
5483
                        Type:   BooleanType,
3✔
5484
                },
3✔
5485
                {
3✔
5486
                        Column: "auto_increment",
3✔
5487
                        Type:   BooleanType,
3✔
5488
                },
3✔
5489
                {
3✔
5490
                        Column: "indexed",
3✔
5491
                        Type:   BooleanType,
3✔
5492
                },
3✔
5493
                {
3✔
5494
                        Column: "primary",
3✔
5495
                        Type:   BooleanType,
3✔
5496
                },
3✔
5497
                {
3✔
5498
                        Column: "unique",
3✔
5499
                        Type:   BooleanType,
3✔
5500
                },
3✔
5501
        }
3✔
5502

3✔
5503
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5504
        if err != nil {
3✔
5505
                return nil, err
×
5506
        }
×
5507

5508
        tableName, err := val.reduce(tx, nil, "")
3✔
5509
        if err != nil {
3✔
5510
                return nil, err
×
5511
        }
×
5512

5513
        if tableName.Type() != VarcharType {
3✔
5514
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5515
        }
×
5516

5517
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5518
        if err != nil {
3✔
5519
                return nil, err
×
5520
        }
×
5521

5522
        values := make([][]ValueExp, len(table.cols))
3✔
5523

3✔
5524
        for i, c := range table.cols {
11✔
5525
                indexed, err := table.IsIndexed(c.Name())
8✔
5526
                if err != nil {
8✔
5527
                        return nil, err
×
5528
                }
×
5529

5530
                var unique bool
8✔
5531
                for _, index := range table.indexesByColID[c.id] {
16✔
5532
                        if index.IsUnique() && len(index.Cols()) == 1 {
11✔
5533
                                unique = true
3✔
5534
                                break
3✔
5535
                        }
5536
                }
5537

5538
                values[i] = []ValueExp{
8✔
5539
                        &Varchar{val: table.name},
8✔
5540
                        &Varchar{val: c.colName},
8✔
5541
                        &Varchar{val: c.colType},
8✔
5542
                        &Integer{val: int64(c.MaxLen())},
8✔
5543
                        &Bool{val: c.IsNullable()},
8✔
5544
                        &Bool{val: c.autoIncrement},
8✔
5545
                        &Bool{val: indexed},
8✔
5546
                        &Bool{val: table.PrimaryIndex().IncludesCol(c.ID())},
8✔
5547
                        &Bool{val: unique},
8✔
5548
                }
8✔
5549
        }
5550

5551
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5552
}
5553

5554
func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
3✔
5555
        if len(stmt.fnCall.params) != 1 {
3✔
5556
                return nil, fmt.Errorf("%w: function '%s' expect table name as parameter", ErrIllegalArguments, IndexesFnCall)
×
5557
        }
×
5558

5559
        cols := []ColDescriptor{
3✔
5560
                {
3✔
5561
                        Column: "table",
3✔
5562
                        Type:   VarcharType,
3✔
5563
                },
3✔
5564
                {
3✔
5565
                        Column: "name",
3✔
5566
                        Type:   VarcharType,
3✔
5567
                },
3✔
5568
                {
3✔
5569
                        Column: "unique",
3✔
5570
                        Type:   BooleanType,
3✔
5571
                },
3✔
5572
                {
3✔
5573
                        Column: "primary",
3✔
5574
                        Type:   BooleanType,
3✔
5575
                },
3✔
5576
        }
3✔
5577

3✔
5578
        val, err := stmt.fnCall.params[0].substitute(params)
3✔
5579
        if err != nil {
3✔
5580
                return nil, err
×
5581
        }
×
5582

5583
        tableName, err := val.reduce(tx, nil, "")
3✔
5584
        if err != nil {
3✔
5585
                return nil, err
×
5586
        }
×
5587

5588
        if tableName.Type() != VarcharType {
3✔
5589
                return nil, fmt.Errorf("%w: expected '%s' for table name but type '%s' given instead", ErrIllegalArguments, VarcharType, tableName.Type())
×
5590
        }
×
5591

5592
        table, err := tx.catalog.GetTableByName(tableName.RawValue().(string))
3✔
5593
        if err != nil {
3✔
5594
                return nil, err
×
5595
        }
×
5596

5597
        values := make([][]ValueExp, len(table.indexes))
3✔
5598

3✔
5599
        for i, index := range table.indexes {
10✔
5600
                values[i] = []ValueExp{
7✔
5601
                        &Varchar{val: table.name},
7✔
5602
                        &Varchar{val: index.Name()},
7✔
5603
                        &Bool{val: index.unique},
7✔
5604
                        &Bool{val: index.IsPrimary()},
7✔
5605
                }
7✔
5606
        }
7✔
5607

5608
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
3✔
5609
}
5610

5611
func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) {
2✔
5612
        if len(stmt.fnCall.params) > 1 {
2✔
5613
                return nil, fmt.Errorf("%w: function '%s' expect at most one parameter of type %s", ErrIllegalArguments, GrantsFnCall, VarcharType)
×
5614
        }
×
5615

5616
        var username string
2✔
5617
        if len(stmt.fnCall.params) == 1 {
3✔
5618
                val, err := stmt.fnCall.params[0].substitute(params)
1✔
5619
                if err != nil {
1✔
5620
                        return nil, err
×
5621
                }
×
5622

5623
                userVal, err := val.reduce(tx, nil, "")
1✔
5624
                if err != nil {
1✔
5625
                        return nil, err
×
5626
                }
×
5627

5628
                if userVal.Type() != VarcharType {
1✔
5629
                        return nil, fmt.Errorf("%w: expected '%s' for username but type '%s' given instead", ErrIllegalArguments, VarcharType, userVal.Type())
×
5630
                }
×
5631
                username, _ = userVal.RawValue().(string)
1✔
5632
        }
5633

5634
        cols := []ColDescriptor{
2✔
5635
                {
2✔
5636
                        Column: "user",
2✔
5637
                        Type:   VarcharType,
2✔
5638
                },
2✔
5639
                {
2✔
5640
                        Column: "privilege",
2✔
5641
                        Type:   VarcharType,
2✔
5642
                },
2✔
5643
        }
2✔
5644

2✔
5645
        var err error
2✔
5646
        var users []User
2✔
5647

2✔
5648
        if tx.engine.multidbHandler == nil {
2✔
5649
                return nil, ErrUnspecifiedMultiDBHandler
×
5650
        } else {
2✔
5651
                users, err = tx.engine.multidbHandler.ListUsers(ctx)
2✔
5652
                if err != nil {
2✔
5653
                        return nil, err
×
5654
                }
×
5655
        }
5656

5657
        values := make([][]ValueExp, 0, len(users))
2✔
5658

2✔
5659
        for _, user := range users {
4✔
5660
                if username == "" || user.Username() == username {
4✔
5661
                        for _, p := range user.SQLPrivileges() {
6✔
5662
                                values = append(values, []ValueExp{
4✔
5663
                                        &Varchar{val: user.Username()},
4✔
5664
                                        &Varchar{val: string(p)},
4✔
5665
                                })
4✔
5666
                        }
4✔
5667
                }
5668
        }
5669

5670
        return NewValuesRowReader(tx, params, cols, true, stmt.Alias(), values)
2✔
5671
}
5672

5673
// DropTableStmt represents a statement to delete a table.
5674
type DropTableStmt struct {
5675
        table string
5676
}
5677

5678
func NewDropTableStmt(table string) *DropTableStmt {
6✔
5679
        return &DropTableStmt{table: table}
6✔
5680
}
6✔
5681

5682
func (stmt *DropTableStmt) readOnly() bool {
1✔
5683
        return false
1✔
5684
}
1✔
5685

5686
func (stmt *DropTableStmt) requiredPrivileges() []SQLPrivilege {
1✔
5687
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5688
}
1✔
5689

5690
func (stmt *DropTableStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5691
        return nil
1✔
5692
}
1✔
5693

5694
/*
5695
Exec executes the delete table statement.
5696
It the table exists, if not it does nothing.
5697
If the table exists, it deletes all the indexes and the table itself.
5698
Note that this is a soft delete of the index and table key,
5699
the data is not deleted, but the metadata is updated.
5700
*/
5701
func (stmt *DropTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
7✔
5702
        if !tx.catalog.ExistTable(stmt.table) {
8✔
5703
                return nil, ErrTableDoesNotExist
1✔
5704
        }
1✔
5705

5706
        table, err := tx.catalog.GetTableByName(stmt.table)
6✔
5707
        if err != nil {
6✔
5708
                return nil, err
×
5709
        }
×
5710

5711
        // delete table
5712
        mappedKey := MapKey(
6✔
5713
                tx.sqlPrefix(),
6✔
5714
                catalogTablePrefix,
6✔
5715
                EncodeID(DatabaseID),
6✔
5716
                EncodeID(table.id),
6✔
5717
        )
6✔
5718
        err = tx.delete(ctx, mappedKey)
6✔
5719
        if err != nil {
6✔
5720
                return nil, err
×
5721
        }
×
5722

5723
        // delete columns
5724
        cols := table.ColumnsByID()
6✔
5725
        for _, col := range cols {
26✔
5726
                mappedKey := MapKey(
20✔
5727
                        tx.sqlPrefix(),
20✔
5728
                        catalogColumnPrefix,
20✔
5729
                        EncodeID(DatabaseID),
20✔
5730
                        EncodeID(col.table.id),
20✔
5731
                        EncodeID(col.id),
20✔
5732
                        []byte(col.colType),
20✔
5733
                )
20✔
5734
                err = tx.delete(ctx, mappedKey)
20✔
5735
                if err != nil {
20✔
5736
                        return nil, err
×
5737
                }
×
5738
        }
5739

5740
        // delete checks
5741
        for name := range table.checkConstraints {
6✔
5742
                key := MapKey(
×
5743
                        tx.sqlPrefix(),
×
5744
                        catalogCheckPrefix,
×
5745
                        EncodeID(DatabaseID),
×
5746
                        EncodeID(table.id),
×
5747
                        []byte(name),
×
5748
                )
×
5749

×
5750
                if err := tx.delete(ctx, key); err != nil {
×
5751
                        return nil, err
×
5752
                }
×
5753
        }
5754

5755
        // delete indexes
5756
        for _, index := range table.indexes {
13✔
5757
                mappedKey := MapKey(
7✔
5758
                        tx.sqlPrefix(),
7✔
5759
                        catalogIndexPrefix,
7✔
5760
                        EncodeID(DatabaseID),
7✔
5761
                        EncodeID(table.id),
7✔
5762
                        EncodeID(index.id),
7✔
5763
                )
7✔
5764
                err = tx.delete(ctx, mappedKey)
7✔
5765
                if err != nil {
7✔
5766
                        return nil, err
×
5767
                }
×
5768

5769
                indexKey := MapKey(
7✔
5770
                        tx.sqlPrefix(),
7✔
5771
                        MappedPrefix,
7✔
5772
                        EncodeID(table.id),
7✔
5773
                        EncodeID(index.id),
7✔
5774
                )
7✔
5775
                err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
14✔
5776
                        return sqlTx.engine.store.DeleteIndex(indexKey)
7✔
5777
                })
7✔
5778
                if err != nil {
7✔
5779
                        return nil, err
×
5780
                }
×
5781
        }
5782

5783
        err = tx.catalog.deleteTable(table)
6✔
5784
        if err != nil {
6✔
5785
                return nil, err
×
5786
        }
×
5787

5788
        tx.mutatedCatalog = true
6✔
5789

6✔
5790
        return tx, nil
6✔
5791
}
5792

5793
// DropIndexStmt represents a statement to delete a table.
5794
type DropIndexStmt struct {
5795
        table string
5796
        cols  []string
5797
}
5798

5799
func NewDropIndexStmt(table string, cols []string) *DropIndexStmt {
4✔
5800
        return &DropIndexStmt{table: table, cols: cols}
4✔
5801
}
4✔
5802

5803
func (stmt *DropIndexStmt) readOnly() bool {
1✔
5804
        return false
1✔
5805
}
1✔
5806

5807
func (stmt *DropIndexStmt) requiredPrivileges() []SQLPrivilege {
1✔
5808
        return []SQLPrivilege{SQLPrivilegeDrop}
1✔
5809
}
1✔
5810

5811
func (stmt *DropIndexStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5812
        return nil
1✔
5813
}
1✔
5814

5815
/*
5816
Exec executes the delete index statement.
5817
If the index exists, it deletes it. Note that this is a soft delete of the index
5818
the data is not deleted, but the metadata is updated.
5819
*/
5820
func (stmt *DropIndexStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
6✔
5821
        if !tx.catalog.ExistTable(stmt.table) {
7✔
5822
                return nil, ErrTableDoesNotExist
1✔
5823
        }
1✔
5824

5825
        table, err := tx.catalog.GetTableByName(stmt.table)
5✔
5826
        if err != nil {
5✔
5827
                return nil, err
×
5828
        }
×
5829

5830
        cols := make([]*Column, len(stmt.cols))
5✔
5831

5✔
5832
        for i, colName := range stmt.cols {
10✔
5833
                col, err := table.GetColumnByName(colName)
5✔
5834
                if err != nil {
5✔
5835
                        return nil, err
×
5836
                }
×
5837

5838
                cols[i] = col
5✔
5839
        }
5840

5841
        index, err := table.GetIndexByName(indexName(table.name, cols))
5✔
5842
        if err != nil {
5✔
5843
                return nil, err
×
5844
        }
×
5845

5846
        // delete index
5847
        mappedKey := MapKey(
5✔
5848
                tx.sqlPrefix(),
5✔
5849
                catalogIndexPrefix,
5✔
5850
                EncodeID(DatabaseID),
5✔
5851
                EncodeID(table.id),
5✔
5852
                EncodeID(index.id),
5✔
5853
        )
5✔
5854
        err = tx.delete(ctx, mappedKey)
5✔
5855
        if err != nil {
5✔
5856
                return nil, err
×
5857
        }
×
5858

5859
        indexKey := MapKey(
5✔
5860
                tx.sqlPrefix(),
5✔
5861
                MappedPrefix,
5✔
5862
                EncodeID(table.id),
5✔
5863
                EncodeID(index.id),
5✔
5864
        )
5✔
5865

5✔
5866
        err = tx.addOnCommittedCallback(func(sqlTx *SQLTx) error {
9✔
5867
                return sqlTx.engine.store.DeleteIndex(indexKey)
4✔
5868
        })
4✔
5869
        if err != nil {
5✔
5870
                return nil, err
×
5871
        }
×
5872

5873
        err = table.deleteIndex(index)
5✔
5874
        if err != nil {
6✔
5875
                return nil, err
1✔
5876
        }
1✔
5877

5878
        tx.mutatedCatalog = true
4✔
5879

4✔
5880
        return tx, nil
4✔
5881
}
5882

5883
// SQLPrivilege names a SQL-level operation that can be granted to or revoked
// from a user (see AlterPrivilegesStmt). Its value is the SQL keyword spelled
// in upper case.
type SQLPrivilege string

// The SQL privileges defined by this package. Each constant's string value is
// the keyword as it appears in SQL text.
const (
	SQLPrivilegeSelect SQLPrivilege = "SELECT"
	SQLPrivilegeCreate SQLPrivilege = "CREATE"
	SQLPrivilegeInsert SQLPrivilege = "INSERT"
	SQLPrivilegeUpdate SQLPrivilege = "UPDATE"
	SQLPrivilegeDelete SQLPrivilege = "DELETE"
	SQLPrivilegeDrop   SQLPrivilege = "DROP"
	SQLPrivilegeAlter  SQLPrivilege = "ALTER"
)

// allPrivileges lists every SQLPrivilege constant above, in declaration order.
// DefaultSQLPrivilegesForPermission uses it as the privilege set for
// full-access permissions; treat it as read-only.
var allPrivileges = []SQLPrivilege{
	SQLPrivilegeSelect,
	SQLPrivilegeCreate,
	SQLPrivilegeInsert,
	SQLPrivilegeUpdate,
	SQLPrivilegeDelete,
	SQLPrivilegeDrop,
	SQLPrivilegeAlter,
}
5904

5905
func DefaultSQLPrivilegesForPermission(p Permission) []SQLPrivilege {
295✔
5906
        switch p {
295✔
5907
        case PermissionSysAdmin, PermissionAdmin, PermissionReadWrite:
284✔
5908
                return allPrivileges
284✔
5909
        case PermissionReadOnly:
11✔
5910
                return []SQLPrivilege{SQLPrivilegeSelect}
11✔
5911
        }
5912
        return nil
×
5913
}
5914

5915
// AlterPrivilegesStmt describes a request to grant or revoke SQL privileges
// for a user on a database. Its execAt method forwards the request to the
// engine's multi-database handler.
type AlterPrivilegesStmt struct {
	database   string         // database name, passed to the handler as-is
	user       string         // user whose privileges are changed
	privileges []SQLPrivilege // privileges to grant or revoke
	isGrant    bool           // true: GrantSQLPrivileges; false: RevokeSQLPrivileges
}
5921

5922
func (stmt *AlterPrivilegesStmt) readOnly() bool {
2✔
5923
        return false
2✔
5924
}
2✔
5925

5926
func (stmt *AlterPrivilegesStmt) requiredPrivileges() []SQLPrivilege {
2✔
5927
        return nil
2✔
5928
}
2✔
5929

5930
func (stmt *AlterPrivilegesStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
2✔
5931
        if tx.IsExplicitCloseRequired() {
3✔
5932
                return nil, fmt.Errorf("%w: user privileges modification can not be done within a transaction", ErrNonTransactionalStmt)
1✔
5933
        }
1✔
5934

5935
        if tx.engine.multidbHandler == nil {
1✔
5936
                return nil, ErrUnspecifiedMultiDBHandler
×
5937
        }
×
5938

5939
        var err error
1✔
5940
        if stmt.isGrant {
1✔
5941
                err = tx.engine.multidbHandler.GrantSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
×
5942
        } else {
1✔
5943
                err = tx.engine.multidbHandler.RevokeSQLPrivileges(ctx, stmt.database, stmt.user, stmt.privileges)
1✔
5944
        }
1✔
5945
        return nil, err
1✔
5946
}
5947

5948
func (stmt *AlterPrivilegesStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error {
1✔
5949
        return nil
1✔
5950
}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc