• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

freeeve / tinykvs / 21180415886

20 Jan 2026 05:06PM UTC coverage: 70.81% (-0.1%) from 70.948%
21180415886

push

github

freeeve
fix(compaction): add SSTable reference counting to prevent use-after-close

Add reference counting to SSTables to prevent concurrent readers from
encountering "file already closed" errors during compaction.

Changes:
- Add refs and markedForRemoval fields to SSTable struct
- Add IncRef/DecRef/MarkForRemoval methods for safe lifecycle management
- Update reader.Get to hold refs while accessing SSTables
- Update ScanPrefix/ScanRange scanners to track and release refs
- Replace direct Close+Remove with MarkForRemoval in compaction

Fixes TestConcurrentReadsDuringCompaction race condition.

50 of 53 new or added lines in 3 files covered. (94.34%)

760 existing lines in 12 files now uncovered.

5594 of 7900 relevant lines covered (70.81%)

405174.35 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.4
/cmd/tinykvs/shell_select.go
1
package main
2

3
import (
4
        "encoding/json"
5
        "fmt"
6
        "math"
7
        "os"
8
        "os/signal"
9
        "strconv"
10
        "strings"
11
        "sync/atomic"
12
        "syscall"
13
        "time"
14

15
        "github.com/blastrain/vitess-sqlparser/sqlparser"
16
        "github.com/freeeve/msgpck"
17
        "github.com/freeeve/tinykvs"
18
)
19

20
// aggType represents an aggregation function type.
21
type aggType int
22

23
const (
24
        aggCount aggType = iota
25
        aggSum
26
        aggAvg
27
        aggMin
28
        aggMax
29
)
30

31
// aggregator holds state for a streaming aggregation.
32
type aggregator struct {
33
        typ   aggType
34
        field string // field path for sum/avg/min/max (empty for count)
35
        alias string // display name
36

37
        // streaming state
38
        count    int64
39
        sum      float64
40
        min      float64
41
        max      float64
42
        hasValue bool
43
}
44

45
func (a *aggregator) update(val tinykvs.Value) {
83✔
46
        a.count++
83✔
47

83✔
48
        if a.field == "" {
107✔
49
                return // count() doesn't need field extraction
24✔
50
        }
24✔
51

52
        // Extract numeric value from field
53
        num, ok := a.extractNumeric(val)
59✔
54
        if !ok {
69✔
55
                return
10✔
56
        }
10✔
57

58
        a.sum += num
49✔
59
        if !a.hasValue {
68✔
60
                a.min = num
19✔
61
                a.max = num
19✔
62
                a.hasValue = true
19✔
63
        } else {
49✔
64
                if num < a.min {
40✔
65
                        a.min = num
10✔
66
                }
10✔
67
                if num > a.max {
48✔
68
                        a.max = num
18✔
69
                }
18✔
70
        }
71
}
72

73
func (a *aggregator) extractNumeric(val tinykvs.Value) (float64, bool) {
59✔
74
        // For non-record types, use the value directly
59✔
75
        if a.field == "" {
59✔
76
                switch val.Type {
×
77
                case tinykvs.ValueTypeInt64:
×
78
                        return float64(val.Int64), true
×
79
                case tinykvs.ValueTypeFloat64:
×
80
                        return val.Float64, true
×
81
                default:
×
82
                        return 0, false
×
83
                }
84
        }
85

86
        // For record/msgpack types, extract the field
87
        var record map[string]any
59✔
88
        var err error
59✔
89
        switch val.Type {
59✔
90
        case tinykvs.ValueTypeRecord:
×
UNCOV
91
                record = val.Record
×
92
        case tinykvs.ValueTypeMsgpack:
59✔
93
                record, err = msgpck.UnmarshalMapStringAny(val.Bytes, false)
59✔
94
                if err != nil {
59✔
95
                        return 0, false
×
96
                }
×
UNCOV
97
        default:
×
UNCOV
98
                return 0, false
×
99
        }
100
        if record == nil {
59✔
UNCOV
101
                return 0, false
×
UNCOV
102
        }
×
103

104
        fieldVal, ok := extractNestedField(record, a.field)
59✔
105
        if !ok {
69✔
106
                return 0, false
10✔
107
        }
10✔
108

109
        switch v := fieldVal.(type) {
49✔
UNCOV
110
        case int:
×
UNCOV
111
                return float64(v), true
×
112
        case int8:
×
113
                return float64(v), true
×
114
        case int16:
×
115
                return float64(v), true
×
116
        case int32:
×
117
                return float64(v), true
×
118
        case int64:
11✔
119
                return float64(v), true
11✔
120
        case uint:
×
121
                return float64(v), true
×
122
        case uint8:
×
123
                return float64(v), true
×
124
        case uint16:
×
125
                return float64(v), true
×
126
        case uint32:
×
127
                return float64(v), true
×
128
        case uint64:
×
129
                return float64(v), true
×
UNCOV
130
        case float32:
×
UNCOV
131
                return float64(v), true
×
132
        case float64:
38✔
133
                return v, true
38✔
UNCOV
134
        default:
×
UNCOV
135
                return 0, false
×
136
        }
137
}
138

139
func (a *aggregator) result() string {
32✔
140
        switch a.typ {
32✔
141
        case aggCount:
9✔
142
                return fmt.Sprintf("%d", a.count)
9✔
143
        case aggSum:
8✔
144
                if !a.hasValue {
10✔
145
                        return "NULL"
2✔
146
                }
2✔
147
                return formatNumber(a.sum)
6✔
148
        case aggAvg:
7✔
149
                if a.count == 0 || !a.hasValue {
9✔
150
                        return "NULL"
2✔
151
                }
2✔
152
                return formatNumber(a.sum / float64(a.count))
5✔
153
        case aggMin:
4✔
154
                if !a.hasValue {
4✔
UNCOV
155
                        return "NULL"
×
UNCOV
156
                }
×
157
                return formatNumber(a.min)
4✔
158
        case aggMax:
4✔
159
                if !a.hasValue {
4✔
UNCOV
160
                        return "NULL"
×
161
                }
×
162
                return formatNumber(a.max)
4✔
UNCOV
163
        default:
×
UNCOV
164
                return "NULL"
×
165
        }
166
}
167

168
func formatNumber(f float64) string {
19✔
169
        if f == math.Trunc(f) {
34✔
170
                return fmt.Sprintf("%.0f", f)
15✔
171
        }
15✔
172
        return fmt.Sprintf("%g", f)
4✔
173
}
174

175
// valueFilter represents a filter condition on a value field
176
type valueFilter struct {
177
        field    string // e.g., "ttl", "name.first"
178
        operator string // "=", "like", ">", "<", ">=", "<="
179
        value    string // the comparison value
180
}
181

182
func (vf *valueFilter) matches(val tinykvs.Value) bool {
×
183
        // Extract the record
×
184
        var record map[string]any
×
185
        var err error
×
186
        switch val.Type {
×
187
        case tinykvs.ValueTypeRecord:
×
188
                record = val.Record
×
189
        case tinykvs.ValueTypeMsgpack:
×
190
                record, err = msgpck.UnmarshalMapStringAny(val.Bytes, false)
×
191
                if err != nil {
×
UNCOV
192
                        return false
×
UNCOV
193
                }
×
UNCOV
194
        default:
×
195
                return false
×
196
        }
197

198
        // Get field value
UNCOV
199
        fieldVal, ok := extractNestedField(record, vf.field)
×
200
        if !ok {
×
201
                return false
×
202
        }
×
203

204
        fieldStr := fmt.Sprintf("%v", fieldVal)
×
205

×
206
        switch vf.operator {
×
207
        case "=":
×
208
                return fieldStr == vf.value
×
209
        case "!=", "<>":
×
210
                return fieldStr != vf.value
×
211
        case "like":
×
212
                // Only prefix matching supported
×
213
                if strings.HasSuffix(vf.value, "%") {
×
214
                        prefix := vf.value[:len(vf.value)-1]
×
215
                        return strings.HasPrefix(fieldStr, prefix)
×
216
                }
×
217
                return fieldStr == vf.value
×
218
        case ">":
×
219
                return fieldStr > vf.value
×
220
        case "<":
×
221
                return fieldStr < vf.value
×
222
        case ">=":
×
223
                return fieldStr >= vf.value
×
UNCOV
224
        case "<=":
×
UNCOV
225
                return fieldStr <= vf.value
×
UNCOV
226
        default:
×
UNCOV
227
                return true
×
228
        }
229
}
230

231
// selectContext holds state for a SELECT query execution.
232
type selectContext struct {
233
        keyEquals    string
234
        keyPrefix    string
235
        keyStart     string
236
        keyEnd       string
237
        valueFilters []*valueFilter
238
        limit        int
239
        fields       []string
240
        aggs         []*aggregator
241
        headers      []string
242
        orderBy      []SortOrder
243
        scanned      int64
244
        scanStats    tinykvs.ScanStats
245
        matchCount   int
246
        bufferedRows [][]string
247
        lastProgress time.Time
248
        startTime    time.Time
249
        interrupted  int32
250
        scanErr      error
251
}
252

253
func (s *Shell) handleSelect(stmt *sqlparser.Select, orderBy []SortOrder) {
132✔
254
        ctx := s.parseSelectStatement(stmt, orderBy)
132✔
255

132✔
256
        sigChan := setupInterruptHandler(&ctx.interrupted)
132✔
257
        defer signal.Stop(sigChan)
132✔
258

132✔
259
        progressCallback := ctx.createProgressCallback()
132✔
260
        processRow := ctx.createRowProcessor()
132✔
261
        safeProcessRow := ctx.wrapRowProcessor(processRow)
132✔
262

132✔
263
        err := s.executeScan(ctx, safeProcessRow, progressCallback)
132✔
264

132✔
265
        ctx.clearProgressLine()
132✔
266

132✔
267
        wasInterrupted := atomic.LoadInt32(&ctx.interrupted) != 0
132✔
268
        if wasInterrupted {
132✔
UNCOV
269
                fmt.Fprintf(os.Stderr, "\r%s\r", strings.Repeat(" ", 80))
×
UNCOV
270
                fmt.Println("^C")
×
UNCOV
271
        }
×
272

273
        if ctx.scanErr != nil {
132✔
UNCOV
274
                fmt.Printf("Scan error: %v\n", ctx.scanErr)
×
UNCOV
275
                return
×
UNCOV
276
        }
×
277
        if err != nil {
136✔
278
                fmt.Printf("Error: %v\n", err)
4✔
279
                return
4✔
280
        }
4✔
281

282
        ctx.renderResults()
128✔
283
        ctx.reportStatistics(wasInterrupted)
128✔
284
}
285

286
func (s *Shell) parseSelectStatement(stmt *sqlparser.Select, orderBy []SortOrder) *selectContext {
132✔
287
        ctx := &selectContext{
132✔
288
                limit:     100,
132✔
289
                orderBy:   orderBy,
132✔
290
                startTime: time.Now(),
132✔
291
        }
132✔
292

132✔
293
        ctx.fields, ctx.aggs = parseSelectExpressions(stmt.SelectExprs)
132✔
294

132✔
295
        if stmt.Limit != nil && stmt.Limit.Rowcount != nil {
138✔
296
                if val, ok := stmt.Limit.Rowcount.(*sqlparser.SQLVal); ok {
12✔
297
                        if n, err := strconv.Atoi(string(val.Val)); err == nil {
12✔
298
                                ctx.limit = n
6✔
299
                        }
6✔
300
                }
301
        }
302

303
        if stmt.Where != nil {
231✔
304
                s.parseWhere(stmt.Where.Expr, &ctx.keyEquals, &ctx.keyPrefix, &ctx.keyStart, &ctx.keyEnd, &ctx.valueFilters)
99✔
305
        }
99✔
306

307
        if len(ctx.fields) > 0 {
163✔
308
                ctx.headers = append([]string{"k"}, ctx.fields...)
31✔
309
        } else {
132✔
310
                ctx.headers = []string{"k", "v"}
101✔
311
        }
101✔
312

313
        return ctx
132✔
314
}
315

316
func parseSelectExpressions(exprs sqlparser.SelectExprs) ([]string, []*aggregator) {
132✔
317
        var fields []string
132✔
318
        var aggs []*aggregator
132✔
319

132✔
320
        for _, expr := range exprs {
285✔
321
                switch e := expr.(type) {
153✔
322
                case *sqlparser.AliasedExpr:
71✔
323
                        if funcExpr, ok := e.Expr.(*sqlparser.FuncExpr); ok {
100✔
324
                                if agg := parseAggregateFunc(funcExpr); agg != nil {
58✔
325
                                        aggs = append(aggs, agg)
29✔
326
                                        continue
29✔
327
                                }
328
                        }
329
                        if col, ok := e.Expr.(*sqlparser.ColName); ok {
84✔
330
                                qualifier := strings.ToLower(col.Qualifier.Name.String())
42✔
331
                                fieldName := col.Name.String()
42✔
332
                                if qualifier == "v" {
76✔
333
                                        fields = append(fields, fieldName)
34✔
334
                                } else if qualifier != "" && qualifier != "kv" {
50✔
335
                                        fields = append(fields, qualifier+"."+fieldName)
8✔
336
                                }
8✔
337
                        }
338
                case *sqlparser.StarExpr:
82✔
339
                        fields = nil
82✔
340
                }
341
        }
342
        return fields, aggs
132✔
343
}
344

345
// setupInterruptHandler routes SIGINT to a new channel and starts a goroutine
// that atomically sets *interrupted to 1 when a signal arrives.
//
// The caller must call signal.Stop on the returned channel when finished. It
// may then close the channel, which makes the goroutine return without
// touching *interrupted.
func setupInterruptHandler(interrupted *int32) chan os.Signal {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT)
	go func() {
		// ok is false when the channel was closed rather than a signal being
		// delivered. Only a real delivery marks the query as interrupted.
		if _, ok := <-sigChan; ok {
			atomic.StoreInt32(interrupted, 1)
		}
	}()
	return sigChan
}
354

355
func (ctx *selectContext) createProgressCallback() tinykvs.ScanProgress {
132✔
356
        return func(stats tinykvs.ScanStats) bool {
132✔
UNCOV
357
                if atomic.LoadInt32(&ctx.interrupted) != 0 {
×
UNCOV
358
                        return false
×
359
                }
×
360
                ctx.scanStats = stats
×
361
                if time.Since(ctx.lastProgress) > time.Second {
×
UNCOV
362
                        elapsed := time.Since(ctx.startTime)
×
UNCOV
363
                        rate := int64(float64(stats.KeysExamined) / elapsed.Seconds())
×
UNCOV
364
                        fmt.Fprintf(os.Stderr, "\rScanned %s keys (%s blocks) in %s (%s keys/sec)...    ",
×
UNCOV
365
                                formatIntCommas(stats.KeysExamined), formatIntCommas(stats.BlocksLoaded),
×
UNCOV
366
                                formatDuration(elapsed), formatIntCommas(rate))
×
UNCOV
367
                        ctx.lastProgress = time.Now()
×
UNCOV
368
                }
×
UNCOV
369
                return true
×
370
        }
371
}
372

373
func (ctx *selectContext) createRowProcessor() func([]byte, tinykvs.Value) bool {
132✔
374
        isAggregate := len(ctx.aggs) > 0
132✔
375
        hasOrderBy := len(ctx.orderBy) > 0
132✔
376
        hasValueFilters := len(ctx.valueFilters) > 0
132✔
377

132✔
378
        return func(key []byte, val tinykvs.Value) bool {
359✔
379
                if atomic.LoadInt32(&ctx.interrupted) != 0 {
227✔
UNCOV
380
                        return false
×
UNCOV
381
                }
×
382
                ctx.scanned++
227✔
383

227✔
384
                if hasValueFilters && time.Since(ctx.lastProgress) > time.Second {
227✔
385
                        ctx.printFilterProgress()
×
UNCOV
386
                }
×
387

388
                if !isAggregate && !hasOrderBy && ctx.matchCount >= ctx.limit {
229✔
389
                        return false
2✔
390
                }
2✔
391

392
                for _, vf := range ctx.valueFilters {
225✔
UNCOV
393
                        if !vf.matches(val) {
×
UNCOV
394
                                return true
×
UNCOV
395
                        }
×
396
                }
397

398
                if isAggregate {
285✔
399
                        for _, agg := range ctx.aggs {
143✔
400
                                agg.update(val)
83✔
401
                        }
83✔
402
                } else {
165✔
403
                        row := extractRowFields(key, val, ctx.fields)
165✔
404
                        ctx.bufferedRows = append(ctx.bufferedRows, row)
165✔
405
                        ctx.matchCount++
165✔
406
                }
165✔
407
                return true
225✔
408
        }
409
}
410

UNCOV
411
func (ctx *selectContext) printFilterProgress() {
×
UNCOV
412
        elapsed := time.Since(ctx.startTime)
×
UNCOV
413
        rate := int64(float64(ctx.scanned) / elapsed.Seconds())
×
UNCOV
414
        fmt.Fprintf(os.Stderr, "\rScanned %s keys (%s blocks) in %s (%s keys/sec), found %d matches...    ",
×
UNCOV
415
                formatIntCommas(ctx.scanned), formatIntCommas(ctx.scanStats.BlocksLoaded),
×
UNCOV
416
                formatDuration(elapsed), formatIntCommas(rate), ctx.matchCount)
×
UNCOV
417
        ctx.lastProgress = time.Now()
×
UNCOV
418
}
×
419

420
func (ctx *selectContext) wrapRowProcessor(processRow func([]byte, tinykvs.Value) bool) func([]byte, tinykvs.Value) bool {
132✔
421
        return func(key []byte, val tinykvs.Value) bool {
359✔
422
                defer func() {
454✔
423
                        if r := recover(); r != nil {
227✔
UNCOV
424
                                ctx.scanErr = fmt.Errorf("panic processing key %x: %v", key[:min(8, len(key))], r)
×
UNCOV
425
                        }
×
426
                }()
427
                return processRow(key, val)
227✔
428
        }
429
}
430

431
func (s *Shell) executeScan(ctx *selectContext, processRow func([]byte, tinykvs.Value) bool, progress tinykvs.ScanProgress) error {
132✔
432
        if ctx.keyEquals != "" {
206✔
433
                return s.executePointLookup(ctx, processRow)
74✔
434
        }
74✔
435
        if ctx.keyPrefix != "" {
73✔
436
                var err error
15✔
437
                ctx.scanStats, err = s.store.ScanPrefixWithStats([]byte(ctx.keyPrefix), processRow, progress)
15✔
438
                return err
15✔
439
        }
15✔
440
        if ctx.keyStart != "" && ctx.keyEnd != "" {
48✔
441
                return s.store.ScanRange([]byte(ctx.keyStart), []byte(ctx.keyEnd), processRow)
5✔
442
        }
5✔
443
        var err error
38✔
444
        ctx.scanStats, err = s.store.ScanPrefixWithStats(nil, processRow, progress)
38✔
445
        return err
38✔
446
}
447

448
func (s *Shell) executePointLookup(ctx *selectContext, processRow func([]byte, tinykvs.Value) bool) error {
74✔
449
        val, err := s.store.Get([]byte(ctx.keyEquals))
74✔
450
        if err == tinykvs.ErrKeyNotFound {
89✔
451
                if len(ctx.aggs) > 0 {
16✔
452
                        printAggregateResults(ctx.aggs)
1✔
453
                } else {
15✔
454
                        printTable(ctx.headers, nil)
14✔
455
                        fmt.Printf("(0 rows)\n")
14✔
456
                }
14✔
457
                return nil
15✔
458
        }
459
        if err != nil {
60✔
460
                return err
1✔
461
        }
1✔
462
        processRow([]byte(ctx.keyEquals), val)
58✔
463
        return nil
58✔
464
}
465

466
func (ctx *selectContext) clearProgressLine() {
132✔
467
        hasValueFilters := len(ctx.valueFilters) > 0
132✔
468
        if (hasValueFilters || ctx.scanned > 10000) && ctx.scanned > 0 {
132✔
UNCOV
469
                fmt.Fprintf(os.Stderr, "\r%s\r", strings.Repeat(" ", 80))
×
UNCOV
470
        }
×
471
}
472

473
func (ctx *selectContext) renderResults() {
128✔
474
        if len(ctx.aggs) > 0 {
147✔
475
                printAggregateResults(ctx.aggs)
19✔
476
                return
19✔
477
        }
19✔
478

479
        if len(ctx.orderBy) > 0 {
119✔
480
                SortRows(ctx.headers, ctx.bufferedRows, ctx.orderBy)
10✔
481
        }
10✔
482

483
        if len(ctx.bufferedRows) > ctx.limit {
109✔
484
                ctx.bufferedRows = ctx.bufferedRows[:ctx.limit]
×
485
        }
×
486
        ctx.matchCount = len(ctx.bufferedRows)
109✔
487

109✔
488
        printTable(ctx.headers, ctx.bufferedRows)
109✔
489
}
490

491
func (ctx *selectContext) reportStatistics(wasInterrupted bool) {
128✔
492
        if ctx.scanned == 0 && ctx.scanStats.BlocksLoaded == 0 {
150✔
493
                fmt.Printf("(%d rows)\n", ctx.matchCount)
22✔
494
                return
22✔
495
        }
22✔
496

497
        elapsed := time.Since(ctx.startTime)
106✔
498
        rate := float64(ctx.scanned) / elapsed.Seconds()
106✔
499

106✔
500
        if wasInterrupted {
106✔
UNCOV
501
                fmt.Printf("(%d rows) - interrupted, scanned %s keys (%s blocks) in %s (%s keys/sec)\n",
×
502
                        ctx.matchCount, formatIntCommas(ctx.scanned),
×
503
                        formatIntCommas(ctx.scanStats.BlocksLoaded), formatDuration(elapsed), formatIntCommas(int64(rate)))
×
504
                return
×
505
        }
×
506

507
        blockDetails := fmt.Sprintf("%d blocks", ctx.scanStats.BlocksLoaded)
106✔
508
        if ctx.scanStats.BlocksCacheHit > 0 || ctx.scanStats.BlocksDiskRead > 0 {
107✔
509
                blockDetails = fmt.Sprintf("%d blocks (%d cache, %d disk)",
1✔
510
                        ctx.scanStats.BlocksLoaded, ctx.scanStats.BlocksCacheHit, ctx.scanStats.BlocksDiskRead)
1✔
511
        }
1✔
512
        tableDetails := ""
106✔
513
        if ctx.scanStats.TablesChecked > 0 {
107✔
514
                tableDetails = fmt.Sprintf(", %d/%d tables", ctx.scanStats.TablesAdded, ctx.scanStats.TablesChecked)
1✔
515
        }
1✔
516
        fmt.Printf("(%d rows) scanned %s keys, %s%s, %s\n",
106✔
517
                ctx.matchCount, formatIntCommas(ctx.scanned), blockDetails, tableDetails, formatDuration(elapsed))
106✔
518
}
519

520
// printStreamingHeader prints column headers for streaming output. It writes
// two lines: the tab-separated header names, then a tab-separated row of
// dashes, one dash per byte of each header.
func printStreamingHeader(headers []string) {
	underline := make([]string, len(headers))
	for i, h := range headers {
		underline[i] = strings.Repeat("-", len(h))
	}
	fmt.Println(strings.Join(headers, "\t"))
	fmt.Println(strings.Join(underline, "\t"))
}
538

539
// printStreamingRow prints a single row for streaming output
UNCOV
540
func printStreamingRow(row []string) {
×
UNCOV
541
        for i, cell := range row {
×
UNCOV
542
                if i > 0 {
×
UNCOV
543
                        fmt.Print("\t")
×
UNCOV
544
                }
×
545
                // Truncate long values
UNCOV
546
                if len(cell) > 60 {
×
UNCOV
547
                        fmt.Print(cell[:57] + "...")
×
UNCOV
548
                } else {
×
UNCOV
549
                        fmt.Print(cell)
×
UNCOV
550
                }
×
551
        }
UNCOV
552
        fmt.Println()
×
553
}
554

555
// printTable prints rows in DuckDB-style box format.
//
// Column widths are measured in runes (Unicode code points), the unit fmt uses
// for %-*s padding, so non-ASCII cells stay aligned. Widths are capped at 50,
// and longer cells are shortened by truncate. Cells beyond len(headers) are
// ignored, and missing cells print as blanks. Nothing is printed when there
// are no headers.
func printTable(headers []string, rows [][]string) {
	if len(headers) == 0 {
		return
	}

	runeWidth := func(s string) int {
		n := 0
		for range s {
			n++
		}
		return n
	}

	// Calculate column widths
	widths := make([]int, len(headers))
	for i, h := range headers {
		widths[i] = runeWidth(h)
	}
	for _, row := range rows {
		for i, cell := range row {
			if i < len(widths) {
				widths[i] = max(widths[i], runeWidth(cell))
			}
		}
	}

	// Cap column widths at 50 chars for readability
	for i := range widths {
		widths[i] = min(widths[i], 50)
	}

	printBoxLine(widths, "┌", "┬", "┐")

	fmt.Print("│")
	for i, h := range headers {
		fmt.Printf(" %-*s │", widths[i], truncate(h, widths[i]))
	}
	fmt.Println()

	printBoxLine(widths, "├", "┼", "┤")

	for _, row := range rows {
		fmt.Print("│")
		for i, w := range widths {
			cell := ""
			if i < len(row) {
				cell = row[i]
			}
			fmt.Printf(" %-*s │", w, truncate(cell, w))
		}
		fmt.Println()
	}

	printBoxLine(widths, "└", "┴", "┘")
}

// printBoxLine prints one horizontal rule of the box table. The rule is left,
// then for each column w+2 "─" (the cell plus its two padding spaces)
// separated by mid, then right and a newline.
func printBoxLine(widths []int, left, mid, right string) {
	segments := make([]string, len(widths))
	for i, w := range widths {
		segments[i] = strings.Repeat("─", w+2)
	}
	fmt.Println(left + strings.Join(segments, mid) + right)
}

// truncate shortens s to at most maxLen runes, never splitting a multi-byte
// UTF-8 sequence.
//
// When s is longer and maxLen > 3, the result is the first maxLen-3 runes plus
// "...". For smaller limits it is the first maxLen runes. A maxLen of zero or
// less yields "".
func truncate(s string, maxLen int) string {
	if maxLen <= 0 {
		return ""
	}
	if len(s) <= maxLen {
		return s // byte length bounds the rune count, so s already fits
	}

	keep, suffix := maxLen, ""
	if maxLen > 3 {
		keep, suffix = maxLen-3, "..."
	}

	cutAt := 0 // byte offset where rune number `keep` starts
	runes := 0
	for i := range s {
		if runes == keep {
			cutAt = i
		}
		runes++
	}
	if runes <= maxLen {
		return s
	}
	return s[:cutAt] + suffix
}
631

632
func parseAggregateFunc(funcExpr *sqlparser.FuncExpr) *aggregator {
29✔
633
        funcName := strings.ToLower(funcExpr.Name.String())
29✔
634

29✔
635
        var typ aggType
29✔
636
        switch funcName {
29✔
637
        case "count":
8✔
638
                typ = aggCount
8✔
639
        case "sum":
7✔
640
                typ = aggSum
7✔
641
        case "avg":
6✔
642
                typ = aggAvg
6✔
643
        case "min":
4✔
644
                typ = aggMin
4✔
645
        case "max":
4✔
646
                typ = aggMax
4✔
UNCOV
647
        default:
×
UNCOV
648
                return nil
×
649
        }
650

651
        agg := &aggregator{typ: typ}
29✔
652

29✔
653
        // Extract field argument for sum/avg/min/max
29✔
654
        if len(funcExpr.Exprs) > 0 {
50✔
655
                if aliased, ok := funcExpr.Exprs[0].(*sqlparser.AliasedExpr); ok {
42✔
656
                        if col, ok := aliased.Expr.(*sqlparser.ColName); ok {
42✔
657
                                qualifier := strings.ToLower(col.Qualifier.Name.String())
21✔
658
                                fieldName := col.Name.String()
21✔
659

21✔
660
                                if qualifier == "v" {
40✔
661
                                        agg.field = fieldName
19✔
662
                                } else if qualifier != "" && qualifier != "kv" {
23✔
663
                                        agg.field = qualifier + "." + fieldName
2✔
664
                                }
2✔
665
                        }
666
                }
667
        }
668

669
        // Build alias for display
670
        if agg.field != "" {
50✔
671
                agg.alias = fmt.Sprintf("%s(v.%s)", funcName, agg.field)
21✔
672
        } else {
29✔
673
                agg.alias = fmt.Sprintf("%s()", funcName)
8✔
674
        }
8✔
675

676
        return agg
29✔
677
}
678

679
func printAggregateResults(aggs []*aggregator) {
20✔
680
        // Print header
20✔
681
        headers := make([]string, len(aggs))
20✔
682
        for i, agg := range aggs {
52✔
683
                headers[i] = agg.alias
32✔
684
        }
32✔
685
        fmt.Println(strings.Join(headers, " | "))
20✔
686

20✔
687
        // Print separator
20✔
688
        seps := make([]string, len(aggs))
20✔
689
        for i, agg := range aggs {
52✔
690
                seps[i] = strings.Repeat("-", len(agg.alias))
32✔
691
        }
32✔
692
        fmt.Println(strings.Join(seps, "-+-"))
20✔
693

20✔
694
        // Print values
20✔
695
        values := make([]string, len(aggs))
20✔
696
        for i, agg := range aggs {
52✔
697
                result := agg.result()
32✔
698
                // Pad to match header width
32✔
699
                if len(result) < len(agg.alias) {
61✔
700
                        result = strings.Repeat(" ", len(agg.alias)-len(result)) + result
29✔
701
                }
29✔
702
                values[i] = result
32✔
703
        }
704
        fmt.Println(strings.Join(values, " | "))
20✔
705
}
706

707
// extractRowFields extracts field values as strings for tabular display
708
func extractRowFields(key []byte, val tinykvs.Value, fields []string) []string {
166✔
709
        keyStr := formatKey(key)
166✔
710

166✔
711
        if len(fields) == 0 {
292✔
712
                // SELECT * - return key and full value
126✔
713
                return []string{keyStr, formatValue(val)}
126✔
714
        }
126✔
715

716
        // Extract specific fields
717
        row := []string{keyStr}
40✔
718

40✔
719
        var record map[string]any
40✔
720
        var err error
40✔
721
        switch val.Type {
40✔
722
        case tinykvs.ValueTypeRecord:
2✔
723
                record = val.Record
2✔
724
        case tinykvs.ValueTypeMsgpack:
36✔
725
                record, err = msgpck.UnmarshalMapStringAny(val.Bytes, false)
36✔
726
                if err != nil {
36✔
727
                        // Can't decode, return NULLs for all fields
×
728
                        for range fields {
×
UNCOV
729
                                row = append(row, "NULL")
×
UNCOV
730
                        }
×
UNCOV
731
                        return row
×
732
                }
733
        default:
2✔
734
                // Non-record type, return NULLs for all fields
2✔
735
                for range fields {
4✔
736
                        row = append(row, "NULL")
2✔
737
                }
2✔
738
                return row
2✔
739
        }
740

741
        for _, field := range fields {
93✔
742
                if v, ok := extractNestedField(record, field); ok {
99✔
743
                        row = append(row, fmt.Sprintf("%v", v))
44✔
744
                } else {
55✔
745
                        row = append(row, "NULL")
11✔
746
                }
11✔
747
        }
748
        return row
38✔
749
}
750

751
func formatValue(val tinykvs.Value) string {
126✔
752
        switch val.Type {
126✔
UNCOV
753
        case tinykvs.ValueTypeInt64:
×
UNCOV
754
                return fmt.Sprintf("%d", val.Int64)
×
755
        case tinykvs.ValueTypeFloat64:
1✔
756
                return fmt.Sprintf("%f", val.Float64)
1✔
757
        case tinykvs.ValueTypeBool:
1✔
758
                return fmt.Sprintf("%t", val.Bool)
1✔
759
        case tinykvs.ValueTypeString, tinykvs.ValueTypeBytes:
83✔
760
                if len(val.Bytes) > 100 {
84✔
761
                        return string(val.Bytes[:100]) + "..."
1✔
762
                }
1✔
763
                return string(val.Bytes)
82✔
764
        case tinykvs.ValueTypeRecord:
1✔
765
                if val.Record == nil {
2✔
766
                        return "{}"
1✔
767
                }
1✔
768
                jsonBytes, _ := json.Marshal(val.Record)
×
769
                return string(jsonBytes)
×
770
        case tinykvs.ValueTypeMsgpack:
39✔
771
                record, err := msgpck.UnmarshalMapStringAny(val.Bytes, false)
39✔
772
                if err != nil {
39✔
773
                        return "(msgpack)"
×
774
                }
×
775
                jsonBytes, _ := json.Marshal(record)
39✔
776
                return string(jsonBytes)
39✔
777
        default:
1✔
778
                return "(unknown)"
1✔
779
        }
780
}
781

782
func (s *Shell) parseWhere(expr sqlparser.Expr, keyEquals, keyPrefix, keyStart, keyEnd *string, valueFilters *[]*valueFilter) {
120✔
783
        switch e := expr.(type) {
120✔
784
        case *sqlparser.ComparisonExpr:
110✔
785
                if col, ok := e.Left.(*sqlparser.ColName); ok {
220✔
786
                        qualifier := strings.ToLower(col.Qualifier.Name.String())
110✔
787
                        colName := strings.ToLower(col.Name.String())
110✔
788

110✔
789
                        // Check if this is a key filter (k = ..., k LIKE ..., etc.)
110✔
790
                        isKeyFilter := (colName == "k" && qualifier == "") ||
110✔
791
                                (colName == "k" && qualifier == "kv")
110✔
792

110✔
793
                        if isKeyFilter {
220✔
794
                                // Key filter
110✔
795
                                val, isHexPrefix := extractValueForLike(e.Right, e.Operator)
110✔
796
                                switch e.Operator {
110✔
797
                                case "=":
83✔
798
                                        *keyEquals = val
83✔
799
                                case "like":
20✔
800
                                        if isHexPrefix {
20✔
UNCOV
801
                                                *keyPrefix = val
×
802
                                        } else if strings.HasSuffix(val, "%") && !strings.Contains(val[:len(val)-1], "%") {
39✔
803
                                                *keyPrefix = val[:len(val)-1]
19✔
804
                                        } else {
20✔
805
                                                fmt.Println("Warning: LIKE only supports prefix matching (e.g., 'prefix%' or x'14%')")
1✔
806
                                        }
1✔
807
                                case ">=":
4✔
808
                                        *keyStart = val
4✔
809
                                case "<=":
3✔
810
                                        *keyEnd = val
3✔
811
                                }
UNCOV
812
                        } else {
×
UNCOV
813
                                // Value field filter: v.field, v.a.b, or just field (assume v.)
×
UNCOV
814
                                var fieldName string
×
UNCOV
815
                                if qualifier == "v" {
×
UNCOV
816
                                        // v.ttl → field is "ttl"
×
UNCOV
817
                                        fieldName = colName
×
UNCOV
818
                                } else if qualifier != "" && qualifier != "kv" {
×
UNCOV
819
                                        // Nested: parsed as ttl.sub → field is "ttl.sub"
×
UNCOV
820
                                        // Or v.address.city parsed as address.city
×
UNCOV
821
                                        fieldName = qualifier + "." + colName
×
UNCOV
822
                                } else {
×
UNCOV
823
                                        // No qualifier, assume it's a value field
×
UNCOV
824
                                        // e.g., just "ttl" means v.ttl
×
UNCOV
825
                                        fieldName = colName
×
UNCOV
826
                                }
×
827

UNCOV
828
                                val := extractValue(e.Right)
×
UNCOV
829
                                *valueFilters = append(*valueFilters, &valueFilter{
×
UNCOV
830
                                        field:    fieldName,
×
UNCOV
831
                                        operator: e.Operator,
×
UNCOV
832
                                        value:    val,
×
UNCOV
833
                                })
×
834
                        }
835
                }
836
        case *sqlparser.RangeCond:
7✔
837
                // BETWEEN
7✔
838
                if col, ok := e.Left.(*sqlparser.ColName); ok {
14✔
839
                        if strings.ToLower(col.Name.String()) == "k" {
14✔
840
                                *keyStart = extractValue(e.From)
7✔
841
                                *keyEnd = extractValue(e.To)
7✔
842
                        }
7✔
843
                }
844
        case *sqlparser.AndExpr:
2✔
845
                s.parseWhere(e.Left, keyEquals, keyPrefix, keyStart, keyEnd, valueFilters)
2✔
846
                s.parseWhere(e.Right, keyEquals, keyPrefix, keyStart, keyEnd, valueFilters)
2✔
847
        }
848
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc