• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bokwoon95 / sq / 9053178445

12 May 2024 05:43PM UTC coverage: 82.269% (-2.7%) from 85.007%
9053178445

push

github

bokwoon95
add support for static queries (raw SQL)

249 of 544 new or added lines in 2 files covered. (45.77%)

4 existing lines in 1 file now uncovered.

5591 of 6796 relevant lines covered (82.27%)

46.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.91
/row_column.go
1
package sq
2

3
import (
4
        "database/sql"
5
        "database/sql/driver"
6
        "encoding/json"
7
        "fmt"
8
        "path/filepath"
9
        "reflect"
10
        "runtime"
11
        "strconv"
12
        "strings"
13
        "time"
14

15
        "github.com/bokwoon95/sq/internal/googleuuid"
16
        "github.com/bokwoon95/sq/internal/pqarray"
17
)
18

19
// Row represents the state of a row after a call to rows.Next().
type Row struct {
	// dialect is the SQL dialect of the query; array() compares it against
	// DialectPostgres to choose between pqarray parsing and JSON decoding.
	dialect string
	// sqlRows is the underlying result set. While it is nil, accessor methods
	// record fields and scan destinations instead of reading values.
	sqlRows *sql.Rows
	// runningIndex is the position in scanDest of the next value to read; each
	// accessor increments it after reading.
	runningIndex int
	// fields holds the fields/expressions recorded by accessor calls while
	// sqlRows is nil, one entry per call.
	fields []Field
	// scanDest holds one scan destination per entry in fields, appended in the
	// same order.
	scanDest []any

	// TODO: call Values using the go-mysql driver and check what the driver
	// returns for bool and time.Time (without calling parseTime). Then we need
	// to accommodate those cases into []byte handling below.
	// TODO: Then! we can finally take the new code for a spin.

	// queryIsStatic is set for raw SQL queries. Accessors then look values up
	// by column name in columnIndex/values instead of using fields/scanDest.
	queryIsStatic bool
	// columns is returned by Columns() for static queries.
	columns []string
	// columnTypes is returned by ColumnTypes() for static queries.
	columnTypes []*sql.ColumnType
	// values holds the current row's values for static queries.
	values []any
	// columnIndex maps a column name to its index in values for static queries.
	columnIndex map[string]int
}
37

38
// Column returns the names of the columns returned by the query. This method
39
// can only be called in a rowmapper if it is paired with a raw SQL query e.g.
40
// Queryf("SELECT * FROM my_table"). Otherwise, an error will be returned.
41
func (row *Row) Columns() []string {
50✔
42
        if row.queryIsStatic {
84✔
43
                return row.columns
34✔
44
        }
34✔
45
        if row.sqlRows == nil {
20✔
46
                return nil
4✔
47
        }
4✔
48
        columns, err := row.sqlRows.Columns()
12✔
49
        if err != nil {
12✔
NEW
50
                panic(fmt.Errorf(callsite(1)+"sqlRows.Columns: %w", err))
×
51
        }
52
        return columns
12✔
53
}
54

55
// ColumnTypes returns the column types returned by the query. This method can
56
// only be called in a rowmapper if it is paired with a raw SQL query e.g.
57
// Queryf("SELECT * FROM my_table"). Otherwise, an error will be returned.
58
func (row *Row) ColumnTypes() []*sql.ColumnType {
32✔
59
        if row.queryIsStatic {
48✔
60
                return row.columnTypes
16✔
61
        }
16✔
62
        if row.sqlRows == nil {
20✔
63
                return nil
4✔
64
        }
4✔
65
        columnTypes, err := row.sqlRows.ColumnTypes()
12✔
66
        if err != nil {
12✔
NEW
67
                panic(fmt.Errorf(callsite(1)+"sqlRows.ColumnTypes: %w", err))
×
68
        }
69
        return columnTypes
12✔
70
}
71

72
// Values returns the values of the current row. This method can only be called
73
// in a rowmapper if it is paired with a raw SQL query e.g. Queryf("SELECT *
74
// FROM my_table"). Otherwise, an error will be returned.
75
func (row *Row) Values() []any {
50✔
76
        if row.queryIsStatic {
84✔
77
                values := make([]any, len(row.values))
34✔
78
                copy(values, row.values)
34✔
79
                return values
34✔
80
        }
34✔
81
        if row.sqlRows == nil {
20✔
82
                return nil
4✔
83
        }
4✔
84
        columns, err := row.sqlRows.Columns()
12✔
85
        if err != nil {
12✔
NEW
86
                panic(fmt.Errorf(callsite(1)+"sqlRows.Columns: %w", err))
×
87
        }
88
        values := make([]any, len(columns))
12✔
89
        scanDest := make([]any, len(columns))
12✔
90
        for i := range values {
228✔
91
                scanDest[i] = &values[i]
216✔
92
        }
216✔
93
        err = row.sqlRows.Scan(scanDest...)
12✔
94
        if err != nil {
12✔
NEW
95
                panic(fmt.Errorf(callsite(1)+"sqlRows.Scan: %w", err))
×
96
        }
97
        return values
12✔
98
}
99

100
// Value returns the value of the expression. It is intended for use cases
101
// where you only know the name of the column but not its type to scan into.
102
// The underlying type of the value is determined by the database driver you
103
// are using.
104
func (row *Row) Value(format string, values ...any) any {
56✔
105
        if row.queryIsStatic {
56✔
NEW
106
                index, ok := row.columnIndex[format]
×
NEW
107
                if !ok {
×
NEW
108
                        panic(fmt.Errorf(callsite(1)+"column %s is not present in query (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
109
                }
NEW
110
                return row.values[index]
×
111
        }
112
        if row.sqlRows == nil {
76✔
113
                var value any
20✔
114
                row.fields = append(row.fields, Expr(format, values...))
20✔
115
                row.scanDest = append(row.scanDest, &value)
20✔
116
                return nil
20✔
117
        }
20✔
118
        defer func() {
72✔
119
                row.runningIndex++
36✔
120
        }()
36✔
121
        scanDest := row.scanDest[row.runningIndex].(*any)
36✔
122
        return *scanDest
36✔
123
}
124

125
// Scan scans the expression into destPtr.
126
func (row *Row) Scan(destPtr any, format string, values ...any) {
224✔
127
        if row.queryIsStatic {
224✔
NEW
128
                panic(fmt.Errorf(callsite(1) + "cannot call Scan for static queries"))
×
129
        }
130
        row.scan(destPtr, Expr(format, values...), 1)
224✔
131
}
132

133
// ScanField scans the field into destPtr.
134
func (row *Row) ScanField(destPtr any, field Field) {
×
NEW
135
        if row.queryIsStatic {
×
NEW
136
                panic(fmt.Errorf(callsite(1) + "cannot call ScanField for static queries"))
×
137
        }
138
        row.scan(destPtr, field, 1)
×
139
}
140

141
func (row *Row) scan(destPtr any, field Field, skip int) {
224✔
142
        if row.sqlRows == nil {
308✔
143
                row.fields = append(row.fields, field)
84✔
144
                switch destPtr.(type) {
84✔
145
                case *bool, *sql.NullBool:
12✔
146
                        row.scanDest = append(row.scanDest, &sql.NullBool{})
12✔
147
                case *float64, *sql.NullFloat64:
12✔
148
                        row.scanDest = append(row.scanDest, &sql.NullFloat64{})
12✔
149
                case *int32, *sql.NullInt32:
12✔
150
                        row.scanDest = append(row.scanDest, &sql.NullInt32{})
12✔
151
                case *int, *int64, *sql.NullInt64:
24✔
152
                        row.scanDest = append(row.scanDest, &sql.NullInt64{})
24✔
153
                case *string, *sql.NullString:
12✔
154
                        row.scanDest = append(row.scanDest, &sql.NullString{})
12✔
155
                case *time.Time, *sql.NullTime:
12✔
156
                        row.scanDest = append(row.scanDest, &sql.NullTime{})
12✔
157
                default:
×
158
                        if reflect.TypeOf(destPtr).Kind() != reflect.Ptr {
×
159
                                panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr))
×
160
                        }
161
                        row.scanDest = append(row.scanDest, destPtr)
×
162
                }
163
                return
84✔
164
        }
165
        defer func() {
280✔
166
                row.runningIndex++
140✔
167
        }()
140✔
168
        switch destPtr := destPtr.(type) {
140✔
169
        case *bool:
16✔
170
                scanDest := row.scanDest[row.runningIndex].(*sql.NullBool)
16✔
171
                *destPtr = scanDest.Bool
16✔
172
        case *sql.NullBool:
4✔
173
                scanDest := row.scanDest[row.runningIndex].(*sql.NullBool)
4✔
174
                *destPtr = *scanDest
4✔
175
        case *float64:
16✔
176
                scanDest := row.scanDest[row.runningIndex].(*sql.NullFloat64)
16✔
177
                *destPtr = scanDest.Float64
16✔
178
        case *sql.NullFloat64:
4✔
179
                scanDest := row.scanDest[row.runningIndex].(*sql.NullFloat64)
4✔
180
                *destPtr = *scanDest
4✔
181
        case *int:
16✔
182
                scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64)
16✔
183
                *destPtr = int(scanDest.Int64)
16✔
184
        case *int32:
16✔
185
                scanDest := row.scanDest[row.runningIndex].(*sql.NullInt32)
16✔
186
                *destPtr = scanDest.Int32
16✔
187
        case *sql.NullInt32:
4✔
188
                scanDest := row.scanDest[row.runningIndex].(*sql.NullInt32)
4✔
189
                *destPtr = *scanDest
4✔
190
        case *int64:
16✔
191
                scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64)
16✔
192
                *destPtr = scanDest.Int64
16✔
193
        case *sql.NullInt64:
8✔
194
                scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64)
8✔
195
                *destPtr = *scanDest
8✔
196
        case *string:
16✔
197
                scanDest := row.scanDest[row.runningIndex].(*sql.NullString)
16✔
198
                *destPtr = scanDest.String
16✔
199
        case *sql.NullString:
4✔
200
                scanDest := row.scanDest[row.runningIndex].(*sql.NullString)
4✔
201
                *destPtr = *scanDest
4✔
202
        case *time.Time:
16✔
203
                scanDest := row.scanDest[row.runningIndex].(*sql.NullTime)
16✔
204
                *destPtr = scanDest.Time
16✔
205
        case *sql.NullTime:
4✔
206
                scanDest := row.scanDest[row.runningIndex].(*sql.NullTime)
4✔
207
                *destPtr = *scanDest
4✔
208
        default:
×
209
                destValue := reflect.ValueOf(destPtr).Elem()
×
NEW
210
                srcValue := reflect.ValueOf(row.scanDest[row.runningIndex]).Elem()
×
211
                destValue.Set(srcValue)
×
212
        }
213
}
214

215
// Array scans the array expression into destPtr. The destPtr must be a pointer
216
// to a []string, []int, []int64, []int32, []float64, []float32 or []bool.
217
func (row *Row) Array(destPtr any, format string, values ...any) {
×
NEW
218
        if row.queryIsStatic {
×
NEW
219
                panic(fmt.Errorf(callsite(1) + "cannot call Array for static queries"))
×
220
        }
221
        row.array(destPtr, Expr(format, values...), 1)
×
222
}
223

224
// ArrayField scans the array field into destPtr. The destPtr must be a pointer
225
// to a []string, []int, []int64, []int32, []float64, []float32 or []bool.
226
func (row *Row) ArrayField(destPtr any, field Array) {
112✔
227
        if row.queryIsStatic {
112✔
NEW
228
                panic(fmt.Errorf(callsite(1) + "cannot call ArrayField for static queries"))
×
229
        }
230
        row.array(destPtr, field, 1)
112✔
231
}
232

233
func (row *Row) array(destPtr any, field Array, skip int) {
112✔
234
        if row.sqlRows == nil {
140✔
235
                if reflect.TypeOf(destPtr).Kind() != reflect.Ptr {
28✔
236
                        panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr))
×
237
                }
238
                if row.dialect == DialectPostgres {
35✔
239
                        switch destPtr.(type) {
7✔
240
                        case *[]string, *[]int, *[]int64, *[]int32, *[]float64, *[]float32, *[]bool:
7✔
241
                                break
7✔
242
                        default:
×
243
                                panic(fmt.Errorf(callsite(skip+1)+"destptr (%T) must be either a pointer to a []string, []int, []int64, []int32, []float64, []float32 or []bool", destPtr))
×
244
                        }
245
                }
246
                row.fields = append(row.fields, field)
28✔
247
                row.scanDest = append(row.scanDest, &nullBytes{
28✔
248
                        dialect:     row.dialect,
28✔
249
                        displayType: displayTypeString,
28✔
250
                })
28✔
251
                return
28✔
252
        }
253
        defer func() {
168✔
254
                row.runningIndex++
84✔
255
        }()
84✔
256
        scanDest := row.scanDest[row.runningIndex].(*nullBytes)
84✔
257
        if !scanDest.valid {
87✔
258
                return
3✔
259
        }
3✔
260
        if row.dialect != DialectPostgres {
144✔
261
                err := json.Unmarshal(scanDest.bytes, destPtr)
63✔
262
                if err != nil {
63✔
263
                        panic(fmt.Errorf(callsite(skip+1)+"unmarshaling json %q into %T: %w", string(scanDest.bytes), destPtr, err))
×
264
                }
265
                return
63✔
266
        }
267
        switch destPtr := destPtr.(type) {
18✔
268
        case *[]string:
3✔
269
                var array pqarray.StringArray
3✔
270
                err := array.Scan(scanDest.bytes)
3✔
271
                if err != nil {
3✔
272
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to string array: %w", string(scanDest.bytes), err))
×
273
                }
274
                *destPtr = array
3✔
275
        case *[]int:
3✔
276
                var array pqarray.Int64Array
3✔
277
                err := array.Scan(scanDest.bytes)
3✔
278
                if err != nil {
3✔
279
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to int64 array: %w", string(scanDest.bytes), err))
×
280
                }
281
                *destPtr = (*destPtr)[:cap(*destPtr)]
3✔
282
                if len(*destPtr) < len(array) {
6✔
283
                        *destPtr = make([]int, len(array))
3✔
284
                }
3✔
285
                *destPtr = (*destPtr)[:len(array)]
3✔
286
                for i, num := range array {
12✔
287
                        (*destPtr)[i] = int(num)
9✔
288
                }
9✔
289
        case *[]int64:
2✔
290
                var array pqarray.Int64Array
2✔
291
                err := array.Scan(scanDest.bytes)
2✔
292
                if err != nil {
2✔
293
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to int64 array: %w", string(scanDest.bytes), err))
×
294
                }
295
                *destPtr = array
2✔
296
        case *[]int32:
2✔
297
                var array pqarray.Int32Array
2✔
298
                err := array.Scan(scanDest.bytes)
2✔
299
                if err != nil {
2✔
300
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to int32 array: %w", string(scanDest.bytes), err))
×
301
                }
302
                *destPtr = array
2✔
303
        case *[]float64:
3✔
304
                var array pqarray.Float64Array
3✔
305
                err := array.Scan(scanDest.bytes)
3✔
306
                if err != nil {
3✔
307
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to float64 array: %w", string(scanDest.bytes), err))
×
308
                }
309
                *destPtr = array
3✔
310
        case *[]float32:
2✔
311
                var array pqarray.Float32Array
2✔
312
                err := array.Scan(scanDest.bytes)
2✔
313
                if err != nil {
2✔
314
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to float32 array: %w", string(scanDest.bytes), err))
×
315
                }
316
                *destPtr = array
2✔
317
        case *[]bool:
3✔
318
                var array pqarray.BoolArray
3✔
319
                err := array.Scan(scanDest.bytes)
3✔
320
                if err != nil {
3✔
321
                        panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to bool array: %w", string(scanDest.bytes), err))
×
322
                }
323
                *destPtr = array
3✔
324
        default:
×
325
                panic(fmt.Errorf(callsite(skip+1)+"destptr (%T) must be either a pointer to a []string, []int, []int64, []int32, []float64, []float32 or []bool", destPtr))
×
326
        }
327
}
328

329
// Bytes returns the []byte value of the expression.
330
func (row *Row) Bytes(format string, values ...any) []byte {
×
NEW
331
        if row.queryIsStatic {
×
NEW
332
                index, ok := row.columnIndex[format]
×
NEW
333
                if !ok {
×
NEW
334
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
335
                }
NEW
336
                value := row.values[index]
×
NEW
337
                switch value := value.(type) {
×
NEW
338
                case int64:
×
NEW
339
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not []byte", value))
×
NEW
340
                case float64:
×
NEW
341
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not []byte", value))
×
NEW
342
                case bool:
×
NEW
343
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not []byte", value))
×
NEW
344
                case []byte:
×
NEW
345
                        return value
×
NEW
346
                case string:
×
NEW
347
                        return []byte(value)
×
NEW
348
                case time.Time:
×
NEW
349
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not []byte", value))
×
NEW
350
                case nil:
×
NEW
351
                        return nil
×
NEW
352
                default:
×
NEW
353
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not []byte", value))
×
354
                }
355
        }
NEW
356
        if row.sqlRows == nil {
×
NEW
357
                row.fields = append(row.fields, Expr(format, values...))
×
NEW
358
                row.scanDest = append(row.scanDest, &nullBytes{
×
NEW
359
                        dialect: row.dialect,
×
NEW
360
                })
×
NEW
361
                return nil
×
NEW
362
        }
×
NEW
363
        defer func() {
×
NEW
364
                row.runningIndex++
×
NEW
365
        }()
×
NEW
366
        scanDest := row.scanDest[row.runningIndex].(*nullBytes)
×
NEW
367
        var b []byte
×
NEW
368
        if scanDest.valid {
×
NEW
369
                b = make([]byte, len(scanDest.bytes))
×
NEW
370
                copy(b, scanDest.bytes)
×
NEW
371
        }
×
NEW
372
        return b
×
373
}
374

375
// BytesField returns the []byte value of the field.
376
func (row *Row) BytesField(field Binary) []byte {
16✔
377
        if row.queryIsStatic {
16✔
NEW
378
                panic(fmt.Errorf(callsite(1) + "cannot call BytesField for static queries"))
×
379
        }
380
        if row.sqlRows == nil {
20✔
381
                row.fields = append(row.fields, field)
4✔
382
                row.scanDest = append(row.scanDest, &nullBytes{
4✔
383
                        dialect: row.dialect,
4✔
384
                })
4✔
385
                return nil
4✔
386
        }
4✔
387
        defer func() {
24✔
388
                row.runningIndex++
12✔
389
        }()
12✔
390
        scanDest := row.scanDest[row.runningIndex].(*nullBytes)
12✔
391
        var b []byte
12✔
392
        if scanDest.valid {
24✔
393
                b = make([]byte, len(scanDest.bytes))
12✔
394
                copy(b, scanDest.bytes)
12✔
395
        }
12✔
396
        return b
12✔
397
}
398

399
// == Bool == //
400

401
// Bool returns the bool value of the expression.
402
func (row *Row) Bool(format string, values ...any) bool {
12✔
403
        if row.queryIsStatic {
24✔
404
                index, ok := row.columnIndex[format]
12✔
405
                if !ok {
12✔
NEW
406
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
407
                }
408
                value := row.values[index]
12✔
409
                switch value := value.(type) {
12✔
NEW
410
                case int64:
×
NEW
411
                        if value == 1 {
×
NEW
412
                                return true
×
NEW
413
                        }
×
NEW
414
                        if value == 0 {
×
NEW
415
                                return false
×
NEW
416
                        }
×
NEW
417
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not bool", value))
×
NEW
418
                case float64:
×
NEW
419
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not bool", value))
×
420
                case bool:
9✔
421
                        return value
9✔
422
                case []byte:
3✔
423
                        // Special case: go-mysql-driver returns everything as []byte.
3✔
424
                        if string(value) == "1" {
6✔
425
                                return true
3✔
426
                        }
3✔
NEW
427
                        if string(value) == "0" {
×
NEW
428
                                return false
×
NEW
429
                        }
×
NEW
430
                        panic(fmt.Errorf(callsite(1)+"%#v is []byte, not bool", value))
×
NEW
431
                case string:
×
NEW
432
                        panic(fmt.Errorf(callsite(1)+"%q is string, not bool", value))
×
NEW
433
                case time.Time:
×
NEW
434
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not bool", value))
×
NEW
435
                case nil:
×
NEW
436
                        return false
×
NEW
437
                default:
×
NEW
438
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not bool", value))
×
439
                }
440
        }
441
        return row.NullBoolField(Expr(format, values...)).Bool
×
442
}
443

444
// BoolField returns the bool value of the field.
445
func (row *Row) BoolField(field Boolean) bool {
16✔
446
        if row.queryIsStatic {
16✔
NEW
447
                panic(fmt.Errorf(callsite(1) + "cannot call BoolField for static queries"))
×
448
        }
449
        return row.NullBoolField(field).Bool
16✔
450
}
451

452
// NullBool returns the sql.NullBool value of the expression.
453
func (row *Row) NullBool(format string, values ...any) sql.NullBool {
4✔
454
        if row.queryIsStatic {
8✔
455
                index, ok := row.columnIndex[format]
4✔
456
                if !ok {
4✔
NEW
457
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
458
                }
459
                value := row.values[index]
4✔
460
                switch value := value.(type) {
4✔
NEW
461
                case int64:
×
NEW
462
                        if value == 1 {
×
NEW
463
                                return sql.NullBool{Bool: true, Valid: true}
×
NEW
464
                        }
×
NEW
465
                        if value == 0 {
×
NEW
466
                                return sql.NullBool{Bool: false, Valid: true}
×
NEW
467
                        }
×
NEW
468
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not bool", value))
×
NEW
469
                case float64:
×
NEW
470
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not bool", value))
×
NEW
471
                case bool:
×
NEW
472
                        return sql.NullBool{Bool: value, Valid: true}
×
NEW
473
                case []byte:
×
NEW
474
                        // Special case: go-mysql-driver returns everything as []byte.
×
NEW
475
                        if string(value) == "1" {
×
NEW
476
                                return sql.NullBool{Bool: true, Valid: true}
×
NEW
477
                        }
×
NEW
478
                        if string(value) == "0" {
×
NEW
479
                                return sql.NullBool{Bool: false, Valid: true}
×
NEW
480
                        }
×
NEW
481
                        panic(fmt.Errorf(callsite(1)+"%d is []byte, not bool", value))
×
NEW
482
                case string:
×
NEW
483
                        panic(fmt.Errorf(callsite(1)+"%q is string, not bool", value))
×
NEW
484
                case time.Time:
×
NEW
485
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not bool", value))
×
486
                case nil:
4✔
487
                        return sql.NullBool{}
4✔
NEW
488
                default:
×
NEW
489
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not bool", value))
×
490
                }
491
        }
492
        return row.NullBoolField(Expr(format, values...))
×
493
}
494

495
// NullBoolField returns the sql.NullBool value of the field.
496
func (row *Row) NullBoolField(field Boolean) sql.NullBool {
16✔
497
        if row.queryIsStatic {
16✔
NEW
498
                panic(fmt.Errorf(callsite(1) + "cannot call NullBoolField for static queries"))
×
499
        }
500
        if row.sqlRows == nil {
20✔
501
                row.fields = append(row.fields, field)
4✔
502
                row.scanDest = append(row.scanDest, &sql.NullBool{})
4✔
503
                return sql.NullBool{}
4✔
504
        }
4✔
505
        defer func() {
24✔
506
                row.runningIndex++
12✔
507
        }()
12✔
508
        scanDest := row.scanDest[row.runningIndex].(*sql.NullBool)
12✔
509
        return *scanDest
12✔
510
}
511

512
// Enum scans the enum expression into destPtr.
513
func (row *Row) Enum(destPtr Enumeration, format string, values ...any) {
×
NEW
514
        if row.queryIsStatic {
×
NEW
515
                panic(fmt.Errorf(callsite(1) + "cannot call Enum for static queries"))
×
516
        }
517
        row.enum(destPtr, Expr(format, values...), 1)
×
518
}
519

520
// EnumField scans the enum field into destPtr.
521
func (row *Row) EnumField(destPtr Enumeration, field Enum) {
48✔
522
        if row.queryIsStatic {
48✔
NEW
523
                panic(fmt.Errorf(callsite(1) + "cannot call EnumField for static queries"))
×
524
        }
525
        row.enum(destPtr, field, 1)
48✔
526
}
527

528
func (row *Row) enum(destPtr Enumeration, field Enum, skip int) {
48✔
529
        if row.sqlRows == nil {
60✔
530
                destType := reflect.TypeOf(destPtr)
12✔
531
                if destType.Kind() != reflect.Ptr {
12✔
532
                        panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr))
×
533
                }
534
                row.fields = append(row.fields, field)
12✔
535
                switch destType.Elem().Kind() {
12✔
536
                case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
537
                        reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
538
                        reflect.String:
12✔
539
                        row.scanDest = append(row.scanDest, &sql.NullString{})
12✔
540
                default:
×
541
                        panic(fmt.Errorf(callsite(skip+1)+"underlying type of %[1]v is neither an integer or string (%[1]T)", destPtr))
×
542
                }
543
                return
12✔
544
        }
545
        defer func() {
72✔
546
                row.runningIndex++
36✔
547
        }()
36✔
548
        scanDest := row.scanDest[row.runningIndex].(*sql.NullString)
36✔
549
        names := destPtr.Enumerate()
36✔
550
        enumIndex := 0
36✔
551
        destValue := reflect.ValueOf(destPtr).Elem()
36✔
552
        if scanDest.Valid {
72✔
553
                enumIndex = getEnumIndex(scanDest.String, names, destValue.Type())
36✔
554
        }
36✔
555
        if enumIndex < 0 {
36✔
556
                panic(fmt.Errorf(callsite(skip+1)+"%q is not a valid %T", scanDest.String, destPtr))
×
557
        }
558
        switch destValue.Kind() {
36✔
559
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
12✔
560
                destValue.SetInt(int64(enumIndex))
12✔
561
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
12✔
562
                destValue.SetUint(uint64(enumIndex))
12✔
563
        case reflect.String:
12✔
564
                destValue.SetString(scanDest.String)
12✔
565
        }
566
}
567

568
// Float64 returns the float64 value of the expression.
569
func (row *Row) Float64(format string, values ...any) float64 {
12✔
570
        if row.queryIsStatic {
24✔
571
                index, ok := row.columnIndex[format]
12✔
572
                if !ok {
12✔
NEW
573
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
574
                }
575
                value := row.values[index]
12✔
576
                switch value := value.(type) {
12✔
NEW
577
                case int64:
×
NEW
578
                        return float64(value)
×
579
                case float64:
9✔
580
                        return value
9✔
NEW
581
                case bool:
×
NEW
582
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not float64", value))
×
583
                case []byte:
3✔
584
                        // Special case: go-mysql-driver returns everything as []byte.
3✔
585
                        n, err := strconv.ParseFloat(string(value), 64)
3✔
586
                        if err != nil {
3✔
NEW
587
                                panic(fmt.Errorf(callsite(1)+"%d is []byte, not float64", value))
×
588
                        }
589
                        return n
3✔
NEW
590
                case string:
×
NEW
591
                        panic(fmt.Errorf(callsite(1)+"%q is string, not float64", value))
×
NEW
592
                case time.Time:
×
NEW
593
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not float64", value))
×
NEW
594
                case nil:
×
NEW
595
                        return 0
×
NEW
596
                default:
×
NEW
597
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not float64", value))
×
598
                }
599
        }
600
        return row.NullFloat64Field(Expr(format, values...)).Float64
×
601
}
602

603
// Float64Field returns the float64 value of the field.
604
func (row *Row) Float64Field(field Number) float64 {
16✔
605
        if row.queryIsStatic {
16✔
NEW
606
                panic(fmt.Errorf(callsite(1) + "cannot call Float64Field for static queries"))
×
607
        }
608
        return row.NullFloat64Field(field).Float64
16✔
609
}
610

611
// NullFloat64 returns the sql.NullFloat64 valye of the expression.
612
func (row *Row) NullFloat64(format string, values ...any) sql.NullFloat64 {
4✔
613
        if row.queryIsStatic {
8✔
614
                index, ok := row.columnIndex[format]
4✔
615
                if !ok {
4✔
NEW
616
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
617
                }
618
                value := row.values[index]
4✔
619
                switch value := value.(type) {
4✔
NEW
620
                case int64:
×
NEW
621
                        return sql.NullFloat64{Float64: float64(value), Valid: true}
×
NEW
622
                case float64:
×
NEW
623
                        return sql.NullFloat64{Float64: value, Valid: true}
×
NEW
624
                case bool:
×
NEW
625
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not float64", value))
×
NEW
626
                case []byte:
×
NEW
627
                        // Special case: go-mysql-driver returns everything as []byte.
×
NEW
628
                        n, err := strconv.ParseFloat(string(value), 64)
×
NEW
629
                        if err != nil {
×
NEW
630
                                panic(fmt.Errorf(callsite(1)+"%d is []byte, not float64", value))
×
631
                        }
NEW
632
                        return sql.NullFloat64{Float64: n, Valid: true}
×
NEW
633
                case string:
×
NEW
634
                        panic(fmt.Errorf(callsite(1)+"%q is string, not float64", value))
×
NEW
635
                case time.Time:
×
NEW
636
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not float64", value))
×
637
                case nil:
4✔
638
                        return sql.NullFloat64{}
4✔
NEW
639
                default:
×
NEW
640
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not float64", value))
×
641
                }
642
        }
643
        return row.NullFloat64Field(Expr(format, values...))
×
644
}
645

646
// NullFloat64Field returns the sql.NullFloat64 value of the field.
647
func (row *Row) NullFloat64Field(field Number) sql.NullFloat64 {
16✔
648
        if row.queryIsStatic {
16✔
NEW
649
                panic(fmt.Errorf(callsite(1) + "cannot call NullFloat64Field for static queries"))
×
650
        }
651
        if row.sqlRows == nil {
20✔
652
                row.fields = append(row.fields, field)
4✔
653
                row.scanDest = append(row.scanDest, &sql.NullFloat64{})
4✔
654
                return sql.NullFloat64{}
4✔
655
        }
4✔
656
        defer func() {
24✔
657
                row.runningIndex++
12✔
658
        }()
12✔
659
        scanDest := row.scanDest[row.runningIndex].(*sql.NullFloat64)
12✔
660
        return *scanDest
12✔
661
}
662

663
// Int returns the int value of the expression.
664
func (row *Row) Int(format string, values ...any) int {
12✔
665
        if row.queryIsStatic {
24✔
666
                index, ok := row.columnIndex[format]
12✔
667
                if !ok {
12✔
NEW
668
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
669
                }
670
                value := row.values[index]
12✔
671
                switch value := value.(type) {
12✔
672
                case int64:
9✔
673
                        return int(value)
9✔
NEW
674
                case float64:
×
NEW
675
                        return int(value)
×
NEW
676
                case bool:
×
NEW
677
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not int", value))
×
678
                case []byte:
3✔
679
                        // Special case: go-mysql-driver returns everything as []byte.
3✔
680
                        n, err := strconv.Atoi(string(value))
3✔
681
                        if err != nil {
3✔
NEW
682
                                panic(fmt.Errorf(callsite(1)+"%d is []byte, not int", value))
×
683
                        }
684
                        return n
3✔
NEW
685
                case string:
×
NEW
686
                        panic(fmt.Errorf(callsite(1)+"%q is string, not int", value))
×
NEW
687
                case time.Time:
×
NEW
688
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not int", value))
×
NEW
689
                case nil:
×
NEW
690
                        return 0
×
NEW
691
                default:
×
NEW
692
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not int", value))
×
693
                }
694
        }
695
        return int(row.NullInt64Field(Expr(format, values...)).Int64)
×
696
}
697

698
// IntField returns the int value of the field.
699
func (row *Row) IntField(field Number) int {
×
NEW
700
        if row.queryIsStatic {
×
NEW
701
                panic(fmt.Errorf(callsite(1) + "cannot call IntField for static queries"))
×
702
        }
703
        return int(row.NullInt64Field(field).Int64)
×
704
}
705

706
// Int64 returns the int64 value of the expression.
707
func (row *Row) Int64(format string, values ...any) int64 {
12✔
708
        if row.queryIsStatic {
24✔
709
                index, ok := row.columnIndex[format]
12✔
710
                if !ok {
12✔
NEW
711
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
712
                }
713
                value := row.values[index]
12✔
714
                switch value := value.(type) {
12✔
715
                case int64:
9✔
716
                        return int64(value)
9✔
NEW
717
                case float64:
×
NEW
718
                        return int64(value)
×
NEW
719
                case bool:
×
NEW
720
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not int64", value))
×
721
                case []byte:
3✔
722
                        // Special case: go-mysql-driver returns everything as []byte.
3✔
723
                        n, err := strconv.ParseInt(string(value), 10, 64)
3✔
724
                        if err != nil {
3✔
NEW
725
                                panic(fmt.Errorf(callsite(1)+"%d is []byte, not int64", value))
×
726
                        }
727
                        return n
3✔
NEW
728
                case string:
×
NEW
729
                        panic(fmt.Errorf(callsite(1)+"%q is string, not int64", value))
×
NEW
730
                case time.Time:
×
NEW
731
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not int64", value))
×
NEW
732
                case nil:
×
NEW
733
                        return 0
×
NEW
734
                default:
×
NEW
735
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not int64", value))
×
736
                }
737
        }
738
        return row.NullInt64Field(Expr(format, values...)).Int64
×
739
}
740

741
// Int64Field returns the int64 value of the field.
742
func (row *Row) Int64Field(field Number) int64 {
16✔
743
        if row.queryIsStatic {
16✔
NEW
744
                panic(fmt.Errorf(callsite(1) + "cannot call Int64Field for static queries"))
×
745
        }
746
        return row.NullInt64Field(field).Int64
16✔
747
}
748

749
// NullInt64 returns the sql.NullInt64 value of the expression.
750
func (row *Row) NullInt64(format string, values ...any) sql.NullInt64 {
4✔
751
        if row.queryIsStatic {
8✔
752
                index, ok := row.columnIndex[format]
4✔
753
                if !ok {
4✔
NEW
754
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
755
                }
756
                value := row.values[index]
4✔
757
                switch value := value.(type) {
4✔
NEW
758
                case int64:
×
NEW
759
                        return sql.NullInt64{Int64: value, Valid: true}
×
NEW
760
                case float64:
×
NEW
761
                        return sql.NullInt64{Int64: int64(value), Valid: true}
×
NEW
762
                case bool:
×
NEW
763
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not int64", value))
×
NEW
764
                case []byte:
×
NEW
765
                        // Special case: go-mysql-driver returns everything as []byte.
×
NEW
766
                        n, err := strconv.ParseInt(string(value), 10, 64)
×
NEW
767
                        if err != nil {
×
NEW
768
                                panic(fmt.Errorf(callsite(1)+"%d is []byte, not int64", value))
×
769
                        }
NEW
770
                        return sql.NullInt64{Int64: n, Valid: true}
×
NEW
771
                case string:
×
NEW
772
                        panic(fmt.Errorf(callsite(1)+"%q is string, not int64", value))
×
NEW
773
                case time.Time:
×
NEW
774
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not int64", value))
×
775
                case nil:
4✔
776
                        return sql.NullInt64{}
4✔
NEW
777
                default:
×
NEW
778
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not int64", value))
×
779
                }
780
        }
781
        return row.NullInt64Field(Expr(format, values...))
×
782
}
783

784
// NullInt64Field returns the sql.NullInt64 value of the field.
785
func (row *Row) NullInt64Field(field Number) sql.NullInt64 {
16✔
786
        if row.queryIsStatic {
16✔
NEW
787
                panic(fmt.Errorf(callsite(1) + "cannot call NullInt64Field for static queries"))
×
788
        }
789
        if row.sqlRows == nil {
20✔
790
                row.fields = append(row.fields, field)
4✔
791
                row.scanDest = append(row.scanDest, &sql.NullInt64{})
4✔
792
                return sql.NullInt64{}
4✔
793
        }
4✔
794
        defer func() {
24✔
795
                row.runningIndex++
12✔
796
        }()
12✔
797
        scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64)
12✔
798
        return *scanDest
12✔
799
}
800

801
// JSON scans the JSON expression into destPtr.
802
func (row *Row) JSON(destPtr any, format string, values ...any) {
×
NEW
803
        if row.queryIsStatic {
×
NEW
804
                panic(fmt.Errorf(callsite(1) + "cannot call JSON for static queries"))
×
805
        }
806
        row.json(destPtr, Expr(format, values...), 1)
×
807
}
808

809
// JSONField scans the JSON field into destPtr.
810
func (row *Row) JSONField(destPtr any, field JSON) {
16✔
811
        if row.queryIsStatic {
16✔
NEW
812
                panic(fmt.Errorf(callsite(1) + "cannot call JSONField for static queries"))
×
813
        }
814
        row.json(destPtr, field, 1)
16✔
815
}
816

817
func (row *Row) json(destPtr any, field JSON, skip int) {
16✔
818
        if row.sqlRows == nil {
20✔
819
                if reflect.TypeOf(destPtr).Kind() != reflect.Ptr {
4✔
820
                        panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr))
×
821
                }
822
                row.fields = append(row.fields, field)
4✔
823
                row.scanDest = append(row.scanDest, &nullBytes{
4✔
824
                        dialect:     row.dialect,
4✔
825
                        displayType: displayTypeString,
4✔
826
                })
4✔
827
                return
4✔
828
        }
829
        defer func() {
24✔
830
                row.runningIndex++
12✔
831
        }()
12✔
832
        scanDest := row.scanDest[row.runningIndex].(*nullBytes)
12✔
833
        if scanDest.valid {
24✔
834
                err := json.Unmarshal(scanDest.bytes, destPtr)
12✔
835
                if err != nil {
12✔
836
                        _, file, line, _ := runtime.Caller(skip + 1)
×
837
                        panic(fmt.Errorf(callsite(skip+1)+"unmarshaling json %q into %T: %w", file, line, string(scanDest.bytes), destPtr, err))
×
838
                }
839
        }
840
}
841

842
// String returns the string value of the expression.
843
func (row *Row) String(format string, values ...any) string {
12✔
844
        if row.queryIsStatic {
24✔
845
                index, ok := row.columnIndex[format]
12✔
846
                if !ok {
12✔
NEW
847
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
848
                }
849
                value := row.values[index]
12✔
850
                switch value := value.(type) {
12✔
NEW
851
                case int64:
×
NEW
852
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not string", value))
×
NEW
853
                case float64:
×
NEW
854
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not string", value))
×
NEW
855
                case bool:
×
NEW
856
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not string", value))
×
857
                case []byte:
3✔
858
                        return string(value)
3✔
859
                case string:
9✔
860
                        return value
9✔
NEW
861
                case time.Time:
×
NEW
862
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not string", value))
×
NEW
863
                case nil:
×
NEW
864
                        return ""
×
NEW
865
                default:
×
NEW
866
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not string", value))
×
867
                }
868
        }
869
        return row.NullStringField(Expr(format, values...)).String
×
870
}
871

872
// String returns the string value of the field.
873
func (row *Row) StringField(field String) string {
72✔
874
        if row.queryIsStatic {
72✔
NEW
875
                panic(fmt.Errorf(callsite(1) + "cannot call StringField for static queries"))
×
876
        }
877
        return row.NullStringField(field).String
72✔
878
}
879

880
// NullString returns the sql.NullString value of the expression.
881
func (row *Row) NullString(format string, values ...any) sql.NullString {
4✔
882
        if row.queryIsStatic {
8✔
883
                index, ok := row.columnIndex[format]
4✔
884
                if !ok {
4✔
NEW
885
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
886
                }
887
                value := row.values[index]
4✔
888
                switch value := value.(type) {
4✔
NEW
889
                case int64:
×
NEW
890
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not string", value))
×
NEW
891
                case float64:
×
NEW
892
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not string", value))
×
NEW
893
                case bool:
×
NEW
894
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not string", value))
×
NEW
895
                case []byte:
×
NEW
896
                        return sql.NullString{String: string(value), Valid: true}
×
NEW
897
                case string:
×
NEW
898
                        return sql.NullString{String: value, Valid: true}
×
NEW
899
                case time.Time:
×
NEW
900
                        panic(fmt.Errorf(callsite(1)+"%v is time.Time, not string", value))
×
901
                case nil:
4✔
902
                        return sql.NullString{}
4✔
NEW
903
                default:
×
NEW
904
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not string", value))
×
905
                }
906
        }
907
        return row.NullStringField(Expr(format, values...))
×
908
}
909

910
// NullStringField returns the sql.NullString value of the field.
911
func (row *Row) NullStringField(field String) sql.NullString {
72✔
912
        if row.queryIsStatic {
72✔
NEW
913
                panic(fmt.Errorf(callsite(1) + "cannot call NullStringField for static queries"))
×
914
        }
915
        if row.sqlRows == nil {
96✔
916
                row.fields = append(row.fields, field)
24✔
917
                row.scanDest = append(row.scanDest, &sql.NullString{})
24✔
918
                return sql.NullString{}
24✔
919
        }
24✔
920
        defer func() {
96✔
921
                row.runningIndex++
48✔
922
        }()
48✔
923
        scanDest := row.scanDest[row.runningIndex].(*sql.NullString)
48✔
924
        return *scanDest
48✔
925
}
926

927
// sqliteTimestampFormats lists the time layouts tried, in order, when a
// timestamp arrives as text. The list is taken from go-sqlite3:
// https://github.com/mattn/go-sqlite3/blob/4396a38886da660e403409e35ef4a37906bf0975/sqlite3.go#L209
var sqliteTimestampFormats = []string{
	// Fractional seconds with a zone offset.
	"2006-01-02 15:04:05.999999999-07:00",
	"2006-01-02T15:04:05.999999999-07:00",
	// Fractional seconds, no zone.
	"2006-01-02 15:04:05.999999999",
	"2006-01-02T15:04:05.999999999",
	// Whole seconds.
	"2006-01-02 15:04:05",
	"2006-01-02T15:04:05",
	// Minutes only.
	"2006-01-02 15:04",
	"2006-01-02T15:04",
	// Date only.
	"2006-01-02",
}
939

940
// Time returns the time.Time value of the expression.
941
func (row *Row) Time(format string, values ...any) time.Time {
12✔
942
        if row.queryIsStatic {
24✔
943
                index, ok := row.columnIndex[format]
12✔
944
                if !ok {
12✔
NEW
945
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
946
                }
947
                value := row.values[index]
12✔
948
                switch value := value.(type) {
12✔
NEW
949
                case int64:
×
NEW
950
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not time.Time", value))
×
NEW
951
                case float64:
×
NEW
952
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not time.Time", value))
×
NEW
953
                case bool:
×
NEW
954
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not time.Time", value))
×
NEW
955
                case []byte:
×
NEW
956
                        // Special case: go-mysql-driver returns everything as []byte.
×
NEW
957
                        s := strings.TrimSuffix(string(value), "Z")
×
NEW
958
                        for _, format := range sqliteTimestampFormats {
×
NEW
959
                                if t, err := time.ParseInLocation(format, s, time.UTC); err == nil {
×
NEW
960
                                        return t
×
NEW
961
                                }
×
962
                        }
NEW
963
                        panic(fmt.Errorf(callsite(1)+"%d is []byte, not time.Time", value))
×
NEW
964
                case string:
×
NEW
965
                        panic(fmt.Errorf(callsite(1)+"%q is string, not time.Time", value))
×
966
                case time.Time:
12✔
967
                        return value
12✔
NEW
968
                case nil:
×
NEW
969
                        return time.Time{}
×
NEW
970
                default:
×
NEW
971
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not time.Time", value))
×
972
                }
973
        }
974
        return row.NullTimeField(Expr(format, values...)).Time
×
975
}
976

977
// Time returns the time.Time value of the field.
978
func (row *Row) TimeField(field Time) time.Time {
16✔
979
        if row.queryIsStatic {
16✔
NEW
980
                panic(fmt.Errorf(callsite(1) + "cannot call TimeField for static queries"))
×
981
        }
982
        return row.NullTimeField(field).Time
16✔
983
}
984

985
// NullTime returns the sql.NullTime value of the expression.
986
func (row *Row) NullTime(format string, values ...any) sql.NullTime {
4✔
987
        if row.queryIsStatic {
8✔
988
                index, ok := row.columnIndex[format]
4✔
989
                if !ok {
4✔
NEW
990
                        panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", ")))
×
991
                }
992
                value := row.values[index]
4✔
993
                switch value := value.(type) {
4✔
NEW
994
                case int64:
×
NEW
995
                        panic(fmt.Errorf(callsite(1)+"%d is int64, not time.Time", value))
×
NEW
996
                case float64:
×
NEW
997
                        panic(fmt.Errorf(callsite(1)+"%d is float64, not time.Time", value))
×
NEW
998
                case bool:
×
NEW
999
                        panic(fmt.Errorf(callsite(1)+"%v is bool, not time.Time", value))
×
NEW
1000
                case []byte:
×
NEW
1001
                        // Special case: go-mysql-driver returns everything as []byte.
×
NEW
1002
                        s := strings.TrimSuffix(string(value), "Z")
×
NEW
1003
                        for _, format := range sqliteTimestampFormats {
×
NEW
1004
                                if t, err := time.ParseInLocation(format, s, time.UTC); err == nil {
×
NEW
1005
                                        return sql.NullTime{Time: t, Valid: true}
×
NEW
1006
                                }
×
1007
                        }
NEW
1008
                        panic(fmt.Errorf(callsite(1)+"%d is []byte, not time.Time", value))
×
NEW
1009
                case string:
×
NEW
1010
                        panic(fmt.Errorf(callsite(1)+"%q is string, not time.Time", value))
×
NEW
1011
                case time.Time:
×
NEW
1012
                        return sql.NullTime{Time: value, Valid: true}
×
1013
                case nil:
4✔
1014
                        return sql.NullTime{}
4✔
NEW
1015
                default:
×
NEW
1016
                        panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not time.Time", value))
×
1017
                }
1018
        }
1019
        return row.NullTimeField(Expr(format, values...))
×
1020
}
1021

1022
// NullTimeField returns the sql.NullTime value of the field.
1023
func (row *Row) NullTimeField(field Time) sql.NullTime {
16✔
1024
        if row.queryIsStatic {
16✔
NEW
1025
                panic(fmt.Errorf(callsite(1) + "cannot call NullTimeField for static queries"))
×
1026
        }
1027
        if row.sqlRows == nil {
20✔
1028
                row.fields = append(row.fields, field)
4✔
1029
                row.scanDest = append(row.scanDest, &sql.NullTime{})
4✔
1030
                return sql.NullTime{}
4✔
1031
        }
4✔
1032
        defer func() {
24✔
1033
                row.runningIndex++
12✔
1034
        }()
12✔
1035
        scanDest := row.scanDest[row.runningIndex].(*sql.NullTime)
12✔
1036
        return *scanDest
12✔
1037
}
1038

1039
// UUID scans the UUID expression into destPtr.
1040
func (row *Row) UUID(destPtr any, format string, values ...any) {
×
NEW
1041
        if row.queryIsStatic {
×
NEW
1042
                panic(fmt.Errorf(callsite(1) + "cannot call UUID for static queries"))
×
1043
        }
1044
        row.uuid(destPtr, Expr(format, values...), 1)
×
1045
}
1046

1047
// UUIDField scans the UUID field into destPtr.
1048
func (row *Row) UUIDField(destPtr any, field UUID) {
16✔
1049
        if row.queryIsStatic {
16✔
NEW
1050
                panic(fmt.Errorf(callsite(1) + "cannot call UUIDField for static queries"))
×
1051
        }
1052
        row.uuid(destPtr, field, 1)
16✔
1053
}
1054

1055
func (row *Row) uuid(destPtr any, field UUID, skip int) {
16✔
1056
        if row.sqlRows == nil {
20✔
1057
                if _, ok := destPtr.(*[16]byte); !ok {
8✔
1058
                        if reflect.TypeOf(destPtr).Kind() != reflect.Ptr {
4✔
1059
                                panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr))
×
1060
                        }
1061
                        destValue := reflect.ValueOf(destPtr).Elem()
4✔
1062
                        if destValue.Kind() != reflect.Array || destValue.Len() != 16 || destValue.Type().Elem().Kind() != reflect.Uint8 {
4✔
1063
                                panic(fmt.Errorf(callsite(skip+1)+"%T is not a pointer to a [16]byte", destPtr))
×
1064
                        }
1065
                }
1066
                row.fields = append(row.fields, field)
4✔
1067
                row.scanDest = append(row.scanDest, &nullBytes{
4✔
1068
                        dialect:     row.dialect,
4✔
1069
                        displayType: displayTypeUUID,
4✔
1070
                })
4✔
1071
                return
4✔
1072
        }
1073
        defer func() {
24✔
1074
                row.runningIndex++
12✔
1075
        }()
12✔
1076
        scanDest := row.scanDest[row.runningIndex].(*nullBytes)
12✔
1077
        var err error
12✔
1078
        var uuid [16]byte
12✔
1079
        if len(scanDest.bytes) == 16 {
21✔
1080
                copy(uuid[:], scanDest.bytes)
9✔
1081
        } else {
12✔
1082
                uuid, err = googleuuid.ParseBytes(scanDest.bytes)
3✔
1083
                if err != nil {
3✔
1084
                        panic(fmt.Errorf(callsite(skip+1)+"parsing %q as UUID string: %w", string(scanDest.bytes), err))
×
1085
                }
1086
        }
1087
        if destArrayPtr, ok := destPtr.(*[16]byte); ok {
12✔
1088
                copy((*destArrayPtr)[:], uuid[:])
×
1089
                return
×
1090
        }
×
1091
        destValue := reflect.ValueOf(destPtr).Elem()
12✔
1092
        for i := 0; i < 16; i++ {
204✔
1093
                destValue.Index(i).Set(reflect.ValueOf(uuid[i]))
192✔
1094
        }
192✔
1095
}
1096

1097
// Column keeps track of what the values mapped to what Field in an
1098
// InsertQuery or SelectQuery.
1099
type Column struct {
1100
        dialect string
1101
        // determines if UPDATE or INSERT
1102
        isUpdate bool
1103
        // UPDATE
1104
        assignments Assignments
1105
        // INSERT
1106
        rowStarted    bool
1107
        rowEnded      bool
1108
        firstField    string
1109
        insertColumns Fields
1110
        rowValues     RowValues
1111
}
1112

1113
// Set maps the value to the Field.
1114
func (col *Column) Set(field Field, value any) {
466✔
1115
        if field == nil {
466✔
1116
                panic(fmt.Errorf(callsite(1) + "setting a nil field"))
×
1117
        }
1118
        // UPDATE mode
1119
        if col.isUpdate {
531✔
1120
                col.assignments = append(col.assignments, Set(field, value))
65✔
1121
                return
65✔
1122
        }
65✔
1123
        // INSERT mode
1124
        name := toString(col.dialect, field)
401✔
1125
        if name == "" {
401✔
1126
                panic(fmt.Errorf(callsite(1) + "field name is empty"))
×
1127
        }
1128
        if !col.rowStarted {
431✔
1129
                col.rowStarted = true
30✔
1130
                col.firstField = name
30✔
1131
                col.insertColumns = append(col.insertColumns, field)
30✔
1132
                col.rowValues = append(col.rowValues, RowValue{value})
30✔
1133
                return
30✔
1134
        }
30✔
1135
        if col.rowStarted && name == col.firstField {
399✔
1136
                if !col.rowEnded {
41✔
1137
                        col.rowEnded = true
13✔
1138
                }
13✔
1139
                // Start a new RowValue
1140
                col.rowValues = append(col.rowValues, RowValue{value})
28✔
1141
                return
28✔
1142
        }
1143
        if !col.rowEnded {
474✔
1144
                col.insertColumns = append(col.insertColumns, field)
131✔
1145
        }
131✔
1146
        // Append to last RowValue
1147
        last := len(col.rowValues) - 1
343✔
1148
        col.rowValues[last] = append(col.rowValues[last], value)
343✔
1149
}
1150

1151
// SetBytes maps the []byte value to the field.
1152
func (col *Column) SetBytes(field Binary, value []byte) { col.Set(field, value) }
12✔
1153

1154
// SetBool maps the bool value to the field.
1155
func (col *Column) SetBool(field Boolean, value bool) { col.Set(field, value) }
12✔
1156

1157
// SetFloat64 maps the float64 value to the field.
1158
func (col *Column) SetFloat64(field Number, value float64) { col.Set(field, value) }
12✔
1159

1160
// SetInt maps the int value to the field.
1161
func (col *Column) SetInt(field Number, value int) { col.Set(field, value) }
5✔
1162

1163
// SetInt64 maps the int64 value to the field.
1164
func (col *Column) SetInt64(field Number, value int64) { col.Set(field, value) }
12✔
1165

1166
// SetString maps the string value to the field.
1167
func (col *Column) SetString(field String, value string) { col.Set(field, value) }
46✔
1168

1169
// SetTime maps the time.Time value to the field.
1170
func (col *Column) SetTime(field Time, value time.Time) { col.Set(field, value) }
17✔
1171

1172
// SetArray maps the array value to the field. The value should be []string,
1173
// []int, []int64, []int32, []float64, []float32 or []bool.
1174
func (col *Column) SetArray(field Array, value any) { col.Set(field, ArrayValue(value)) }
84✔
1175

1176
// SetEnum maps the enum value to the field.
1177
func (col *Column) SetEnum(field Enum, value Enumeration) { col.Set(field, EnumValue(value)) }
36✔
1178

1179
// SetJSON maps the JSON value to the field. The value should be able to be
1180
// convertible to JSON using json.Marshal.
1181
func (col *Column) SetJSON(field JSON, value any) { col.Set(field, JSONValue(value)) }
12✔
1182

1183
// SetUUID maps the UUID value to the field. The value's type or underlying
1184
// type should be [16]byte.
1185
func (col *Column) SetUUID(field UUID, value any) { col.Set(field, UUIDValue(value)) }
12✔
1186

1187
// callsite returns a "file.go:LINE: " prefix naming the caller skip frames
// above callsite's own caller (skip 0 is the function that called callsite).
// Only the base name of the file is used. It returns "" when the requested
// frame is not available.
func callsite(skip int) string {
	if _, file, line, ok := runtime.Caller(skip + 1); ok {
		return filepath.Base(file) + ":" + strconv.Itoa(line) + ": "
	}
	return ""
}
1194

1195
// displayType controls how a nullBytes value is rendered by its Value method.
type displayType int8

const (
	displayTypeBinary displayType = iota
	displayTypeString
	displayTypeUUID
)

// nullBytes is used in place of scanning into *[]byte. We use *nullBytes
// instead of *[]byte because of the displayType field, which determines how to
// render the value to the user. This is important for logging the query
// results, because UUIDs/JSON/Arrays are all scanned into bytes but we don't
// want to display them as bytes (we need to convert them to UUID/JSON/Array
// strings instead).
type nullBytes struct {
	bytes       []byte
	dialect     string
	displayType displayType
	valid       bool // false when the scanned value was NULL
}

// Scan implements sql.Scanner. NULL resets the value to invalid. A string is
// converted to bytes, and a []byte is copied. Any other type is an error.
func (n *nullBytes) Scan(value any) error {
	if value == nil {
		n.bytes, n.valid = nil, false
		return nil
	}
	n.valid = true
	switch value := value.(type) {
	case string:
		n.bytes = []byte(value)
	case []byte:
		// database/sql: a []byte handed to Scan is owned by the driver and is
		// only valid until the next call to Scan. nullBytes keeps the bytes
		// around, so it must take its own copy.
		n.bytes = make([]byte, len(value))
		copy(n.bytes, value)
	default:
		return fmt.Errorf("unable to convert %#v to bytes", value)
	}
	return nil
}
1232

1233
func (n *nullBytes) Value() (driver.Value, error) {
120✔
1234
        if !n.valid {
123✔
1235
                return nil, nil
3✔
1236
        }
3✔
1237
        switch n.displayType {
117✔
1238
        case displayTypeString:
93✔
1239
                return string(n.bytes), nil
93✔
1240
        case displayTypeUUID:
12✔
1241
                if n.dialect != "postgres" {
21✔
1242
                        return n.bytes, nil
9✔
1243
                }
9✔
1244
                var uuid [16]byte
3✔
1245
                var buf [36]byte
3✔
1246
                copy(uuid[:], n.bytes)
3✔
1247
                googleuuid.EncodeHex(buf[:], uuid)
3✔
1248
                return string(buf[:]), nil
3✔
1249
        default:
12✔
1250
                return n.bytes, nil
12✔
1251
        }
1252
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc