• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bokwoon95 / sq / 9053178445

12 May 2024 05:43PM UTC coverage: 82.269% (-2.7%) from 85.007%
9053178445

push

github

bokwoon95
add support for static queries (raw SQL)

249 of 544 new or added lines in 2 files covered. (45.77%)

4 existing lines in 1 file now uncovered.

5591 of 6796 relevant lines covered (82.27%)

46.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.69
/fetch_exec.go
1
package sq
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "fmt"
8
        "reflect"
9
        "runtime"
10
        "strconv"
11
        "strings"
12
        "sync/atomic"
13
        "time"
14
)
15

16
var (
	// DefaultDialect is the dialect used by any query whose GetDialect
	// returns "" (i.e. no dialect was explicitly provided). It is held in an
	// atomic.Pointer so it can be loaded and stored concurrently; a nil
	// pointer means no default has been configured.
	DefaultDialect atomic.Pointer[string]
)
18

19
// A Cursor represents a database cursor.
20
type Cursor[T any] struct {
21
        ctx           context.Context
22
        row           *Row
23
        rowmapper     func(*Row) T
24
        queryStats    QueryStats
25
        logSettings   LogSettings
26
        logger        SqLogger
27
        logged        int32
28
        fieldNames    []string
29
        resultsBuffer *bytes.Buffer
30
}
31

32
// FetchCursor returns a new cursor.
33
func FetchCursor[T any](db DB, query Query, rowmapper func(*Row) T) (*Cursor[T], error) {
×
34
        return fetchCursor(context.Background(), db, query, rowmapper, 1)
×
35
}
×
36

37
// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
38
func FetchCursorContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) (*Cursor[T], error) {
×
39
        return fetchCursor(ctx, db, query, rowmapper, 1)
×
40
}
×
41

42
func fetchCursor[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T, skip int) (cursor *Cursor[T], err error) {
36✔
43
        if db == nil {
36✔
44
                return nil, fmt.Errorf("db is nil")
×
45
        }
×
46
        if query == nil {
36✔
47
                return nil, fmt.Errorf("query is nil")
×
48
        }
×
49
        if rowmapper == nil {
36✔
50
                return nil, fmt.Errorf("rowmapper is nil")
×
51
        }
×
52
        dialect := query.GetDialect()
36✔
53
        if dialect == "" {
36✔
54
                defaultDialect := DefaultDialect.Load()
×
55
                if defaultDialect != nil {
×
56
                        dialect = *defaultDialect
×
57
                }
×
58
        }
59
        // If we can't set the fetchable fields, the query is static.
60
        _, ok := query.SetFetchableFields(nil)
36✔
61
        cursor = &Cursor[T]{
36✔
62
                ctx:       ctx,
36✔
63
                rowmapper: rowmapper,
36✔
64
                row: &Row{
36✔
65
                        dialect:       dialect,
36✔
66
                        queryIsStatic: !ok,
36✔
67
                },
36✔
68
                queryStats: QueryStats{
36✔
69
                        Dialect:  dialect,
36✔
70
                        Params:   make(map[string][]int),
36✔
71
                        RowCount: sql.NullInt64{Valid: true},
36✔
72
                },
36✔
73
        }
36✔
74

36✔
75
        // If the query is dynamic, call the rowmapper to populate row.fields and
36✔
76
        // row.scanDest. Then, insert those fields back into the query.
36✔
77
        if !cursor.row.queryIsStatic {
54✔
78
                defer mapperFunctionPanicked(&err)
18✔
79
                _ = cursor.rowmapper(cursor.row)
18✔
80
                query, _ = query.SetFetchableFields(cursor.row.fields)
18✔
81
        }
18✔
82

83
        // Build query.
84
        buf := bufpool.Get().(*bytes.Buffer)
36✔
85
        buf.Reset()
36✔
86
        defer bufpool.Put(buf)
36✔
87
        err = query.WriteSQL(ctx, dialect, buf, &cursor.queryStats.Args, cursor.queryStats.Params)
36✔
88
        cursor.queryStats.Query = buf.String()
36✔
89
        if err != nil {
36✔
90
                return nil, err
×
91
        }
×
92

93
        // Setup logger.
94
        cursor.logger, _ = db.(SqLogger)
36✔
95
        if cursor.logger == nil {
60✔
96
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
24✔
97
                if logQuery != nil {
24✔
98
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
99
                        cursor.logger = &sqLogStruct{
×
100
                                logSettings: logSettings,
×
101
                                logQuery:    logQuery,
×
102
                        }
×
103
                }
×
104
        }
105
        if cursor.logger != nil {
48✔
106
                cursor.logger.SqLogSettings(ctx, &cursor.logSettings)
12✔
107
                if cursor.logSettings.IncludeCaller {
24✔
108
                        cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1)
12✔
109
                }
12✔
110
        }
111

112
        // Run query.
113
        if cursor.logSettings.IncludeTime {
48✔
114
                cursor.queryStats.StartedAt = time.Now()
12✔
115
        }
12✔
116
        cursor.row.sqlRows, cursor.queryStats.Err = db.QueryContext(ctx, cursor.queryStats.Query, cursor.queryStats.Args...)
36✔
117
        if cursor.logSettings.IncludeTime {
48✔
118
                cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt)
12✔
119
        }
12✔
120
        if cursor.queryStats.Err != nil {
36✔
121
                cursor.log()
×
122
                return nil, cursor.queryStats.Err
×
123
        }
×
124

125
        // If the query is static, we now know the number of columns returned by
126
        // the query and can allocate the values slice and scanDest slice for
127
        // scanning later.
128
        if cursor.row.queryIsStatic {
54✔
129
                cursor.row.columns, err = cursor.row.sqlRows.Columns()
18✔
130
                if err != nil {
18✔
NEW
131
                        return nil, err
×
NEW
132
                }
×
133
                cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes()
18✔
134
                if err != nil {
18✔
NEW
135
                        return nil, err
×
NEW
136
                }
×
137
                cursor.row.columnIndex = make(map[string]int)
18✔
138
                for index, column := range cursor.row.columns {
122✔
139
                        cursor.row.columnIndex[column] = index
104✔
140
                }
104✔
141
                cursor.row.values = make([]any, len(cursor.row.columns))
18✔
142
                cursor.row.scanDest = make([]any, len(cursor.row.columns))
18✔
143
                for index := range cursor.row.values {
122✔
144
                        cursor.row.scanDest[index] = &cursor.row.values[index]
104✔
145
                }
104✔
146
        }
147

148
        // Allocate the resultsBuffer.
149
        if cursor.logSettings.IncludeResults > 0 {
42✔
150
                cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
6✔
151
                cursor.resultsBuffer.Reset()
6✔
152
        }
6✔
153
        return cursor, nil
36✔
154
}
155

156
// Next advances the cursor to the next result.
157
func (cursor *Cursor[T]) Next() bool {
122✔
158
        hasNext := cursor.row.sqlRows.Next()
122✔
159
        if hasNext {
222✔
160
                cursor.queryStats.RowCount.Int64++
100✔
161
        } else {
122✔
162
                cursor.log()
22✔
163
        }
22✔
164
        return hasNext
122✔
165
}
166

167
// RowCount returns the current row number so far.
168
func (cursor *Cursor[T]) RowCount() int64 { return cursor.queryStats.RowCount.Int64 }
22✔
169

170
// Result returns the cursor result.
171
func (cursor *Cursor[T]) Result() (result T, err error) {
100✔
172
        err = cursor.row.sqlRows.Scan(cursor.row.scanDest...)
100✔
173
        if err != nil {
100✔
NEW
174
                cursor.log()
×
NEW
175
                fieldMappings := getFieldMappings(cursor.queryStats.Dialect, cursor.row.fields, cursor.row.scanDest)
×
NEW
176
                return result, fmt.Errorf("please check if your mapper function is correct:%s\n%w", fieldMappings, err)
×
UNCOV
177
        }
×
178
        // If results should be logged, write the row into the resultsBuffer.
179
        if cursor.resultsBuffer != nil && cursor.queryStats.RowCount.Int64 <= int64(cursor.logSettings.IncludeResults) {
142✔
180
                if len(cursor.fieldNames) == 0 {
52✔
181
                        cursor.fieldNames = getFieldNames(cursor.ctx, cursor.row)
10✔
182
                }
10✔
183
                cursor.resultsBuffer.WriteString("\n----[ Row " + strconv.FormatInt(cursor.queryStats.RowCount.Int64, 10) + " ]----")
42✔
184
                for i := range cursor.row.scanDest {
378✔
185
                        cursor.resultsBuffer.WriteString("\n")
336✔
186
                        if i < len(cursor.fieldNames) {
672✔
187
                                cursor.resultsBuffer.WriteString(cursor.fieldNames[i])
336✔
188
                        }
336✔
189
                        cursor.resultsBuffer.WriteString(": ")
336✔
190
                        scanDest := cursor.row.scanDest[i]
336✔
191
                        rhs, err := Sprint(cursor.queryStats.Dialect, scanDest)
336✔
192
                        if err != nil {
336✔
193
                                cursor.resultsBuffer.WriteString("%!(error=" + err.Error() + ")")
×
194
                                continue
×
195
                        }
196
                        cursor.resultsBuffer.WriteString(rhs)
336✔
197
                }
198
        }
199
        cursor.row.runningIndex = 0
100✔
200
        defer mapperFunctionPanicked(&err)
100✔
201
        result = cursor.rowmapper(cursor.row)
100✔
202
        return result, nil
100✔
203
}
204

205
func (cursor *Cursor[T]) log() {
110✔
206
        if !atomic.CompareAndSwapInt32(&cursor.logged, 0, 1) {
176✔
207
                return
66✔
208
        }
66✔
209
        if cursor.resultsBuffer != nil {
54✔
210
                cursor.queryStats.Results = cursor.resultsBuffer.String()
10✔
211
                bufpool.Put(cursor.resultsBuffer)
10✔
212
        }
10✔
213
        if cursor.logger == nil {
68✔
214
                return
24✔
215
        }
24✔
216
        if cursor.logSettings.LogAsynchronously {
20✔
217
                go cursor.logger.SqLogQuery(cursor.ctx, cursor.queryStats)
×
218
        } else {
20✔
219
                cursor.logger.SqLogQuery(cursor.ctx, cursor.queryStats)
20✔
220
        }
20✔
221
}
222

223
// Close closes the cursor.
224
func (cursor *Cursor[T]) Close() error {
88✔
225
        cursor.log()
88✔
226
        if err := cursor.row.sqlRows.Close(); err != nil {
88✔
227
                return err
×
228
        }
×
229
        if err := cursor.row.sqlRows.Err(); err != nil {
88✔
230
                return err
×
231
        }
×
232
        return nil
88✔
233
}
234

235
// FetchOne returns the first result from running the given Query on the given
236
// DB.
237
func FetchOne[T any](db DB, query Query, rowmapper func(*Row) T) (T, error) {
18✔
238
        cursor, err := fetchCursor(context.Background(), db, query, rowmapper, 1)
18✔
239
        if err != nil {
18✔
240
                return *new(T), err
×
241
        }
×
242
        defer cursor.Close()
18✔
243
        return cursorResult(cursor)
18✔
244
}
245

246
// FetchOneContext is like FetchOne but additionally requires a context.Context.
247
func FetchOneContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) (T, error) {
×
248
        cursor, err := fetchCursor(ctx, db, query, rowmapper, 1)
×
249
        if err != nil {
×
250
                return *new(T), err
×
251
        }
×
252
        defer cursor.Close()
×
253
        return cursorResult(cursor)
×
254
}
255

256
// FetchAll returns all results from running the given Query on the given DB.
257
func FetchAll[T any](db DB, query Query, rowmapper func(*Row) T) ([]T, error) {
18✔
258
        cursor, err := fetchCursor(context.Background(), db, query, rowmapper, 1)
18✔
259
        if err != nil {
18✔
UNCOV
260
                return nil, err
×
UNCOV
261
        }
×
262
        defer cursor.Close()
18✔
263
        return cursorResults(cursor)
18✔
264
}
265

266
// FetchAllContext is like FetchAll but additionally requires a context.Context.
267
func FetchAllContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) ([]T, error) {
×
268
        cursor, err := fetchCursor(ctx, db, query, rowmapper, 1)
×
269
        if err != nil {
×
270
                return nil, err
×
271
        }
×
272
        defer cursor.Close()
×
273
        return cursorResults(cursor)
×
274
}
275

276
// CompiledFetch is the result of compiling a Query down into a query string
277
// and args slice. A CompiledFetch can be safely executed in parallel.
278
type CompiledFetch[T any] struct {
279
        dialect   string
280
        query     string
281
        args      []any
282
        params    map[string][]int
283
        rowmapper func(*Row) T
284
        // if queryIsStatic is true, the rowmapper doesn't actually know what
285
        // columns are in the query and it must be determined at runtime after
286
        // running the query.
287
        queryIsStatic bool
288
}
289

290
// NewCompiledFetch returns a new CompiledFetch.
291
func NewCompiledFetch[T any](dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) *CompiledFetch[T] {
4✔
292
        return &CompiledFetch[T]{
4✔
293
                dialect:   dialect,
4✔
294
                query:     query,
4✔
295
                args:      args,
4✔
296
                params:    params,
4✔
297
                rowmapper: rowmapper,
4✔
298
        }
4✔
299
}
4✔
300

301
// CompileFetch returns a new CompileFetch.
302
func CompileFetch[T any](q Query, rowmapper func(*Row) T) (*CompiledFetch[T], error) {
4✔
303
        return CompileFetchContext(context.Background(), q, rowmapper)
4✔
304
}
4✔
305

306
// CompileFetchContext is like CompileFetch but accepts a context.Context.
307
func CompileFetchContext[T any](ctx context.Context, query Query, rowmapper func(*Row) T) (compiledFetch *CompiledFetch[T], err error) {
8✔
308
        if query == nil {
8✔
309
                return nil, fmt.Errorf("query is nil")
×
310
        }
×
311
        if rowmapper == nil {
8✔
312
                return nil, fmt.Errorf("rowmapper is nil")
×
313
        }
×
314
        dialect := query.GetDialect()
8✔
315
        if dialect == "" {
8✔
316
                defaultDialect := DefaultDialect.Load()
×
317
                if defaultDialect != nil {
×
318
                        dialect = *defaultDialect
×
319
                }
×
320
        }
321
        // If we can't set the fetchable fields, the query is static.
322
        _, ok := query.SetFetchableFields(nil)
8✔
323
        compiledFetch = &CompiledFetch[T]{
8✔
324
                dialect:       dialect,
8✔
325
                params:        make(map[string][]int),
8✔
326
                rowmapper:     rowmapper,
8✔
327
                queryIsStatic: !ok,
8✔
328
        }
8✔
329
        row := &Row{
8✔
330
                dialect:       dialect,
8✔
331
                queryIsStatic: !ok,
8✔
332
        }
8✔
333

8✔
334
        // If the query is dynamic, call the rowmapper to populate row.fields.
8✔
335
        // Then, insert those fields back into the query.
8✔
336
        if !row.queryIsStatic {
12✔
337
                defer mapperFunctionPanicked(&err)
4✔
338
                _ = rowmapper(row)
4✔
339
                query, _ = query.SetFetchableFields(row.fields)
4✔
340
        }
4✔
341

342
        // Build query.
343
        buf := bufpool.Get().(*bytes.Buffer)
8✔
344
        buf.Reset()
8✔
345
        defer bufpool.Put(buf)
8✔
346
        err = query.WriteSQL(ctx, dialect, buf, &compiledFetch.args, compiledFetch.params)
8✔
347
        compiledFetch.query = buf.String()
8✔
348
        if err != nil {
8✔
349
                return nil, err
×
350
        }
×
351
        return compiledFetch, nil
8✔
352
}
353

354
// FetchCursor returns a new cursor.
355
func (compiledFetch *CompiledFetch[T]) FetchCursor(db DB, params Params) (*Cursor[T], error) {
×
356
        return compiledFetch.fetchCursor(context.Background(), db, params, 1)
×
357
}
×
358

359
// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
360
func (compiledFetch *CompiledFetch[T]) FetchCursorContext(ctx context.Context, db DB, params Params) (*Cursor[T], error) {
×
361
        return compiledFetch.fetchCursor(ctx, db, params, 1)
×
362
}
×
363

364
func (compiledFetch *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, params Params, skip int) (cursor *Cursor[T], err error) {
4✔
365
        if db == nil {
4✔
366
                return nil, fmt.Errorf("db is nil")
×
367
        }
×
368
        cursor = &Cursor[T]{
4✔
369
                ctx:       ctx,
4✔
370
                rowmapper: compiledFetch.rowmapper,
4✔
371
                row: &Row{
4✔
372
                        dialect:       compiledFetch.dialect,
4✔
373
                        queryIsStatic: compiledFetch.queryIsStatic,
4✔
374
                },
4✔
375
                queryStats: QueryStats{
4✔
376
                        Dialect: compiledFetch.dialect,
4✔
377
                        Query:   compiledFetch.query,
4✔
378
                        Args:    compiledFetch.args,
4✔
379
                        Params:  compiledFetch.params,
4✔
380
                },
4✔
381
        }
4✔
382

4✔
383
        // Call the rowmapper to populate row.scanDest.
4✔
384
        if !cursor.row.queryIsStatic {
6✔
385
                defer mapperFunctionPanicked(&err)
2✔
386
                _ = cursor.rowmapper(cursor.row)
2✔
387
        }
2✔
388

389
        // Substitute params.
390
        cursor.queryStats.Args, err = substituteParams(cursor.queryStats.Dialect, cursor.queryStats.Args, cursor.queryStats.Params, params)
4✔
391
        if err != nil {
4✔
392
                return nil, err
×
393
        }
×
394

395
        // Setup logger.
396
        cursor.queryStats.RowCount.Valid = true
4✔
397
        cursor.logger, _ = db.(SqLogger)
4✔
398
        if cursor.logger == nil {
4✔
399
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
400
                if logQuery != nil {
×
401
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
402
                        cursor.logger = &sqLogStruct{
×
403
                                logSettings: logSettings,
×
404
                                logQuery:    logQuery,
×
405
                        }
×
406
                }
×
407
        }
408
        if cursor.logger != nil {
8✔
409
                cursor.logger.SqLogSettings(ctx, &cursor.logSettings)
4✔
410
                if cursor.logSettings.IncludeCaller {
8✔
411
                        cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1)
4✔
412
                }
4✔
413
        }
414

415
        // Run query.
416
        if cursor.logSettings.IncludeTime {
8✔
417
                cursor.queryStats.StartedAt = time.Now()
4✔
418
        }
4✔
419
        cursor.row.sqlRows, cursor.queryStats.Err = db.QueryContext(ctx, cursor.queryStats.Query, cursor.queryStats.Args...)
4✔
420
        if cursor.logSettings.IncludeTime {
8✔
421
                cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt)
4✔
422
        }
4✔
423
        if cursor.queryStats.Err != nil {
4✔
424
                return nil, cursor.queryStats.Err
×
425
        }
×
426

427
        // If the query is static, we now know the number of columns returned by
428
        // the query and can allocate the values slice and scanDest slice for
429
        // scanning later.
430
        if cursor.row.queryIsStatic {
6✔
431
                cursor.row.columns, err = cursor.row.sqlRows.Columns()
2✔
432
                if err != nil {
2✔
NEW
433
                        return nil, err
×
NEW
434
                }
×
435
                cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes()
2✔
436
                if err != nil {
2✔
NEW
437
                        return nil, err
×
NEW
438
                }
×
439
                cursor.row.columnIndex = make(map[string]int)
2✔
440
                for index, column := range cursor.row.columns {
10✔
441
                        cursor.row.columnIndex[column] = index
8✔
442
                }
8✔
443
                cursor.row.values = make([]any, len(cursor.row.columns))
2✔
444
                cursor.row.scanDest = make([]any, len(cursor.row.columns))
2✔
445
                for index := range cursor.row.values {
10✔
446
                        cursor.row.scanDest[index] = &cursor.row.values[index]
8✔
447
                }
8✔
448
        }
449

450
        // Allocate the resultsBuffer.
451
        if cursor.logSettings.IncludeResults > 0 {
6✔
452
                cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
2✔
453
                cursor.resultsBuffer.Reset()
2✔
454
        }
2✔
455
        return cursor, nil
4✔
456
}
457

458
// FetchOne returns the first result from running the CompiledFetch on the
459
// given DB with the give params.
460
func (compiledFetch *CompiledFetch[T]) FetchOne(db DB, params Params) (T, error) {
2✔
461
        cursor, err := compiledFetch.fetchCursor(context.Background(), db, params, 1)
2✔
462
        if err != nil {
2✔
463
                return *new(T), err
×
464
        }
×
465
        defer cursor.Close()
2✔
466
        return cursorResult(cursor)
2✔
467
}
468

469
// FetchOneContext is like FetchOne but additionally requires a context.Context.
470
func (compiledFetch *CompiledFetch[T]) FetchOneContext(ctx context.Context, db DB, params Params) (T, error) {
×
471
        cursor, err := compiledFetch.fetchCursor(ctx, db, params, 1)
×
472
        if err != nil {
×
473
                return *new(T), err
×
474
        }
×
475
        defer cursor.Close()
×
476
        return cursorResult(cursor)
×
477
}
478

479
// FetchAll returns all the results from running the CompiledFetch on the given
480
// DB with the give params.
481
func (compiledFetch *CompiledFetch[T]) FetchAll(db DB, params Params) ([]T, error) {
2✔
482
        cursor, err := compiledFetch.fetchCursor(context.Background(), db, params, 1)
2✔
483
        if err != nil {
2✔
484
                return nil, err
×
485
        }
×
486
        defer cursor.Close()
2✔
487
        return cursorResults(cursor)
2✔
488
}
489

490
// FetchAllContext is like FetchAll but additionally requires a context.Context.
491
func (compiledFetch *CompiledFetch[T]) FetchAllContext(ctx context.Context, db DB, params Params) ([]T, error) {
×
492
        cursor, err := compiledFetch.fetchCursor(ctx, db, params, 1)
×
493
        if err != nil {
×
494
                return nil, err
×
495
        }
×
496
        defer cursor.Close()
×
497
        return cursorResults(cursor)
×
498
}
499

500
// GetSQL returns a copy of the dialect, query, args, params and rowmapper that
501
// make up the CompiledFetch.
502
func (compiledFetch *CompiledFetch[T]) GetSQL() (dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) {
4✔
503
        dialect = compiledFetch.dialect
4✔
504
        query = compiledFetch.query
4✔
505
        args = make([]any, len(compiledFetch.args))
4✔
506
        params = make(map[string][]int)
4✔
507
        copy(args, compiledFetch.args)
4✔
508
        for name, indexes := range compiledFetch.params {
6✔
509
                indexes2 := make([]int, len(indexes))
2✔
510
                copy(indexes2, indexes)
2✔
511
                params[name] = indexes2
2✔
512
        }
2✔
513
        return dialect, query, args, params, compiledFetch.rowmapper
4✔
514
}
515

516
// Prepare creates a PreparedFetch from a CompiledFetch by preparing it on
517
// the given DB.
518
func (compiledFetch *CompiledFetch[T]) Prepare(db DB) (*PreparedFetch[T], error) {
×
519
        return compiledFetch.PrepareContext(context.Background(), db)
×
520
}
×
521

522
// PrepareContext is like Prepare but additionally requires a context.Context.
523
func (compiledFetch *CompiledFetch[T]) PrepareContext(ctx context.Context, db DB) (*PreparedFetch[T], error) {
4✔
524
        var err error
4✔
525
        preparedFetch := &PreparedFetch[T]{
4✔
526
                compiledFetch: NewCompiledFetch(compiledFetch.GetSQL()),
4✔
527
        }
4✔
528
        preparedFetch.compiledFetch.queryIsStatic = compiledFetch.queryIsStatic
4✔
529
        if db == nil {
4✔
530
                return nil, fmt.Errorf("db is nil")
×
531
        }
×
532
        preparedFetch.stmt, err = db.PrepareContext(ctx, compiledFetch.query)
4✔
533
        if err != nil {
4✔
534
                return nil, err
×
535
        }
×
536
        preparedFetch.logger, _ = db.(SqLogger)
4✔
537
        if preparedFetch.logger == nil {
4✔
538
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
539
                if logQuery != nil {
×
540
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
541
                        preparedFetch.logger = &sqLogStruct{
×
542
                                logSettings: logSettings,
×
543
                                logQuery:    logQuery,
×
544
                        }
×
545
                }
×
546
        }
547
        return preparedFetch, nil
4✔
548
}
549

550
// PreparedFetch is the result of preparing a CompiledFetch on a DB.
551
type PreparedFetch[T any] struct {
552
        compiledFetch *CompiledFetch[T]
553
        stmt          *sql.Stmt
554
        logger        SqLogger
555
}
556

557
// PrepareFetch returns a new PreparedFetch.
558
func PrepareFetch[T any](db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) {
4✔
559
        return PrepareFetchContext(context.Background(), db, q, rowmapper)
4✔
560
}
4✔
561

562
// PrepareFetchContext is like PrepareFetch but additionally requires a context.Context.
563
func PrepareFetchContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) {
4✔
564
        compiledFetch, err := CompileFetchContext(ctx, q, rowmapper)
4✔
565
        if err != nil {
4✔
566
                return nil, err
×
567
        }
×
568
        return compiledFetch.PrepareContext(ctx, db)
4✔
569
}
570

571
// FetchCursor returns a new cursor.
572
func (preparedFetch PreparedFetch[T]) FetchCursor(params Params) (*Cursor[T], error) {
×
573
        return preparedFetch.fetchCursor(context.Background(), params, 1)
×
574
}
×
575

576
// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
577
func (preparedFetch PreparedFetch[T]) FetchCursorContext(ctx context.Context, params Params) (*Cursor[T], error) {
×
578
        return preparedFetch.fetchCursor(ctx, params, 1)
×
579
}
×
580

581
func (preparedFetch *PreparedFetch[T]) fetchCursor(ctx context.Context, params Params, skip int) (cursor *Cursor[T], err error) {
4✔
582
        cursor = &Cursor[T]{
4✔
583
                ctx:       ctx,
4✔
584
                rowmapper: preparedFetch.compiledFetch.rowmapper,
4✔
585
                row: &Row{
4✔
586
                        dialect:       preparedFetch.compiledFetch.dialect,
4✔
587
                        queryIsStatic: preparedFetch.compiledFetch.queryIsStatic,
4✔
588
                },
4✔
589
                queryStats: QueryStats{
4✔
590
                        Dialect:  preparedFetch.compiledFetch.dialect,
4✔
591
                        Query:    preparedFetch.compiledFetch.query,
4✔
592
                        Args:     preparedFetch.compiledFetch.args,
4✔
593
                        Params:   preparedFetch.compiledFetch.params,
4✔
594
                        RowCount: sql.NullInt64{Valid: true},
4✔
595
                },
4✔
596
                logger: preparedFetch.logger,
4✔
597
        }
4✔
598

4✔
599
        // If the query is dynamic, call the rowmapper to populate row.scanDest.
4✔
600
        if !cursor.row.queryIsStatic {
6✔
601
                defer mapperFunctionPanicked(&err)
2✔
602
                _ = cursor.rowmapper(cursor.row)
2✔
603
        }
2✔
604

605
        // Substitute params.
606
        cursor.queryStats.Args, err = substituteParams(cursor.queryStats.Dialect, cursor.queryStats.Args, cursor.queryStats.Params, params)
4✔
607
        if err != nil {
4✔
608
                return nil, err
×
609
        }
×
610

611
        // Setup logger.
612
        if cursor.logger != nil {
8✔
613
                cursor.logger.SqLogSettings(ctx, &cursor.logSettings)
4✔
614
                if cursor.logSettings.IncludeCaller {
8✔
615
                        cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1)
4✔
616
                }
4✔
617
        }
618

619
        // Run query.
620
        if cursor.logSettings.IncludeTime {
8✔
621
                cursor.queryStats.StartedAt = time.Now()
4✔
622
        }
4✔
623
        cursor.row.sqlRows, cursor.queryStats.Err = preparedFetch.stmt.QueryContext(ctx, cursor.queryStats.Args...)
4✔
624
        if cursor.logSettings.IncludeTime {
8✔
625
                cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt)
4✔
626
        }
4✔
627
        if cursor.queryStats.Err != nil {
4✔
628
                return nil, cursor.queryStats.Err
×
629
        }
×
630

631
        // If the query is static, we now know the number of columns returned by
632
        // the query and can allocate the values slice and scanDest slice for
633
        // scanning later.
634
        if cursor.row.queryIsStatic {
6✔
635
                cursor.row.columns, err = cursor.row.sqlRows.Columns()
2✔
636
                if err != nil {
2✔
NEW
637
                        return nil, err
×
NEW
638
                }
×
639
                cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes()
2✔
640
                if err != nil {
2✔
NEW
641
                        return nil, err
×
NEW
642
                }
×
643
                cursor.row.columnIndex = make(map[string]int)
2✔
644
                for index, column := range cursor.row.columns {
10✔
645
                        cursor.row.columnIndex[column] = index
8✔
646
                }
8✔
647
                cursor.row.values = make([]any, len(cursor.row.columns))
2✔
648
                cursor.row.scanDest = make([]any, len(cursor.row.columns))
2✔
649
                for index := range cursor.row.values {
10✔
650
                        cursor.row.scanDest[index] = &cursor.row.values[index]
8✔
651
                }
8✔
652
        }
653

654
        // Allocate the resultsBuffer.
655
        if cursor.logSettings.IncludeResults > 0 {
6✔
656
                cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
2✔
657
                cursor.resultsBuffer.Reset()
2✔
658
        }
2✔
659
        return cursor, nil
4✔
660
}
661

662
// FetchOne returns the first result from running the PreparedFetch with the
663
// give params.
664
func (preparedFetch *PreparedFetch[T]) FetchOne(params Params) (T, error) {
2✔
665
        cursor, err := preparedFetch.fetchCursor(context.Background(), params, 1)
2✔
666
        if err != nil {
2✔
667
                return *new(T), err
×
668
        }
×
669
        defer cursor.Close()
2✔
670
        return cursorResult(cursor)
2✔
671
}
672

673
// FetchOneContext is like FetchOne but additionally requires a context.Context.
674
func (preparedFetch *PreparedFetch[T]) FetchOneContext(ctx context.Context, params Params) (T, error) {
×
675
        cursor, err := preparedFetch.fetchCursor(ctx, params, 1)
×
676
        if err != nil {
×
677
                return *new(T), err
×
678
        }
×
679
        defer cursor.Close()
×
680
        return cursorResult(cursor)
×
681
}
682

683
// FetchAll returns all the results from running the PreparedFetch with the
684
// give params.
685
func (preparedFetch *PreparedFetch[T]) FetchAll(params Params) ([]T, error) {
2✔
686
        cursor, err := preparedFetch.fetchCursor(context.Background(), params, 1)
2✔
687
        if err != nil {
2✔
688
                return nil, err
×
689
        }
×
690
        defer cursor.Close()
2✔
691
        return cursorResults(cursor)
2✔
692
}
693

694
// FetchAllContext is like FetchAll but additionally requires a context.Context.
695
func (preparedFetch *PreparedFetch[T]) FetchAllContext(ctx context.Context, params Params) ([]T, error) {
×
696
        cursor, err := preparedFetch.fetchCursor(ctx, params, 1)
×
697
        if err != nil {
×
698
                return nil, err
×
699
        }
×
700
        defer cursor.Close()
×
701
        return cursorResults(cursor)
×
702
}
703

704
// GetCompiled returns a copy of the underlying CompiledFetch.
705
func (preparedFetch *PreparedFetch[T]) GetCompiled() *CompiledFetch[T] {
×
NEW
706
        compiledFetch := NewCompiledFetch(preparedFetch.compiledFetch.GetSQL())
×
NEW
707
        compiledFetch.queryIsStatic = preparedFetch.compiledFetch.queryIsStatic
×
NEW
708
        return compiledFetch
×
UNCOV
709
}
×
710

711
// Close closes the PreparedFetch.
712
func (preparedFetch *PreparedFetch[T]) Close() error {
×
713
        if preparedFetch.stmt == nil {
×
714
                return nil
×
715
        }
×
716
        return preparedFetch.stmt.Close()
×
717
}
718

719
// Exec executes the given Query on the given DB.
720
func Exec(db DB, query Query) (Result, error) {
9✔
721
        return exec(context.Background(), db, query, 1)
9✔
722
}
9✔
723

724
// ExecContext is like Exec but additionally requires a context.Context.
725
func ExecContext(ctx context.Context, db DB, query Query) (Result, error) {
×
726
        return exec(ctx, db, query, 1)
×
727
}
×
728

729
func exec(ctx context.Context, db DB, query Query, skip int) (result Result, err error) {
9✔
730
        if db == nil {
9✔
731
                return result, fmt.Errorf("db is nil")
×
732
        }
×
733
        if query == nil {
9✔
734
                return result, fmt.Errorf("query is nil")
×
735
        }
×
736
        dialect := query.GetDialect()
9✔
737
        if dialect == "" {
9✔
738
                defaultDialect := DefaultDialect.Load()
×
739
                if defaultDialect != nil {
×
740
                        dialect = *defaultDialect
×
741
                }
×
742
        }
743
        queryStats := QueryStats{
9✔
744
                Dialect: dialect,
9✔
745
                Params:  make(map[string][]int),
9✔
746
        }
9✔
747

9✔
748
        // Build query.
9✔
749
        buf := bufpool.Get().(*bytes.Buffer)
9✔
750
        buf.Reset()
9✔
751
        defer bufpool.Put(buf)
9✔
752
        err = query.WriteSQL(ctx, dialect, buf, &queryStats.Args, queryStats.Params)
9✔
753
        queryStats.Query = buf.String()
9✔
754
        if err != nil {
9✔
755
                return result, err
×
756
        }
×
757

758
        // Setup logger.
759
        var logSettings LogSettings
9✔
760
        logger, _ := db.(SqLogger)
9✔
761
        if logger == nil {
9✔
762
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
763
                if logQuery != nil {
×
764
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
765
                        logger = &sqLogStruct{
×
766
                                logSettings: logSettings,
×
767
                                logQuery:    logQuery,
×
768
                        }
×
769
                }
×
770
        }
771
        if logger != nil {
18✔
772
                logger.SqLogSettings(ctx, &logSettings)
9✔
773
                if logSettings.IncludeCaller {
18✔
774
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
9✔
775
                }
9✔
776
                defer func() {
18✔
777
                        if logSettings.LogAsynchronously {
9✔
778
                                go logger.SqLogQuery(ctx, queryStats)
×
779
                        } else {
9✔
780
                                logger.SqLogQuery(ctx, queryStats)
9✔
781
                        }
9✔
782
                }()
783
        }
784

785
        // Run query.
786
        if logSettings.IncludeTime {
18✔
787
                queryStats.StartedAt = time.Now()
9✔
788
        }
9✔
789
        var sqlResult sql.Result
9✔
790
        sqlResult, queryStats.Err = db.ExecContext(ctx, queryStats.Query, queryStats.Args...)
9✔
791
        if logSettings.IncludeTime {
18✔
792
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
9✔
793
        }
9✔
794
        if queryStats.Err != nil {
9✔
795
                return result, queryStats.Err
×
796
        }
×
797
        return execResult(sqlResult, &queryStats)
9✔
798
}
799

800
// CompiledExec is the result of compiling a Query down into a query string and
// args slice. A CompiledExec can be safely executed in parallel.
type CompiledExec struct {
	// dialect is the SQL dialect the query was rendered for; query is the
	// rendered SQL string.
	dialect, query string

	// args are the query arguments; params maps each named param to the
	// positions in args that it controls.
	args   []any
	params map[string][]int
}
808

809
// NewCompiledExec returns a new CompiledExec.
810
func NewCompiledExec(dialect string, query string, args []any, params map[string][]int) *CompiledExec {
1✔
811
        return &CompiledExec{
1✔
812
                dialect: dialect,
1✔
813
                query:   query,
1✔
814
                args:    args,
1✔
815
                params:  params,
1✔
816
        }
1✔
817
}
1✔
818

819
// CompileExec returns a new CompiledExec.
820
func CompileExec(query Query) (*CompiledExec, error) {
1✔
821
        return CompileExecContext(context.Background(), query)
1✔
822
}
1✔
823

824
// CompileExecContext is like CompileExec but additionally requires a context.Context.
825
func CompileExecContext(ctx context.Context, query Query) (*CompiledExec, error) {
2✔
826
        if query == nil {
2✔
827
                return nil, fmt.Errorf("query is nil")
×
828
        }
×
829
        dialect := query.GetDialect()
2✔
830
        if dialect == "" {
2✔
831
                defaultDialect := DefaultDialect.Load()
×
832
                if defaultDialect != nil {
×
833
                        dialect = *defaultDialect
×
834
                }
×
835
        }
836
        compiledExec := &CompiledExec{
2✔
837
                dialect: dialect,
2✔
838
                params:  make(map[string][]int),
2✔
839
        }
2✔
840

2✔
841
        // Build query.
2✔
842
        buf := bufpool.Get().(*bytes.Buffer)
2✔
843
        buf.Reset()
2✔
844
        defer bufpool.Put(buf)
2✔
845
        err := query.WriteSQL(ctx, dialect, buf, &compiledExec.args, compiledExec.params)
2✔
846
        compiledExec.query = buf.String()
2✔
847
        if err != nil {
2✔
848
                return nil, err
×
849
        }
×
850
        return compiledExec, nil
2✔
851
}
852

853
// Exec executes the CompiledExec on the given DB with the given params.
854
func (compiledExec *CompiledExec) Exec(db DB, params Params) (Result, error) {
5✔
855
        return compiledExec.exec(context.Background(), db, params, 1)
5✔
856
}
5✔
857

858
// ExecContext is like Exec but additionally requires a context.Context.
859
func (compiledExec *CompiledExec) ExecContext(ctx context.Context, db DB, params Params) (Result, error) {
×
860
        return compiledExec.exec(ctx, db, params, 1)
×
861
}
×
862

863
func (compiledExec *CompiledExec) exec(ctx context.Context, db DB, params Params, skip int) (result Result, err error) {
5✔
864
        if db == nil {
5✔
865
                return result, fmt.Errorf("db is nil")
×
866
        }
×
867
        queryStats := QueryStats{
5✔
868
                Dialect: compiledExec.dialect,
5✔
869
                Query:   compiledExec.query,
5✔
870
                Args:    compiledExec.args,
5✔
871
                Params:  compiledExec.params,
5✔
872
        }
5✔
873

5✔
874
        // Setup logger.
5✔
875
        var logSettings LogSettings
5✔
876
        logger, _ := db.(SqLogger)
5✔
877
        if logger == nil {
5✔
878
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
879
                if logQuery != nil {
×
880
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
881
                        logger = &sqLogStruct{
×
882
                                logSettings: logSettings,
×
883
                                logQuery:    logQuery,
×
884
                        }
×
885
                }
×
886
        }
887
        if logger != nil {
10✔
888
                logger.SqLogSettings(ctx, &logSettings)
5✔
889
                if logSettings.IncludeCaller {
10✔
890
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
5✔
891
                }
5✔
892
                defer func() {
10✔
893
                        if logSettings.LogAsynchronously {
5✔
894
                                go logger.SqLogQuery(ctx, queryStats)
×
895
                        } else {
5✔
896
                                logger.SqLogQuery(ctx, queryStats)
5✔
897
                        }
5✔
898
                }()
899
        }
900

901
        // Substitute params.
902
        queryStats.Args, err = substituteParams(queryStats.Dialect, queryStats.Args, queryStats.Params, params)
5✔
903
        if err != nil {
5✔
904
                return result, err
×
905
        }
×
906

907
        // Run query.
908
        if logSettings.IncludeTime {
10✔
909
                queryStats.StartedAt = time.Now()
5✔
910
        }
5✔
911
        var sqlResult sql.Result
5✔
912
        sqlResult, queryStats.Err = db.ExecContext(ctx, queryStats.Query, queryStats.Args...)
5✔
913
        if logSettings.IncludeTime {
10✔
914
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
5✔
915
        }
5✔
916
        if queryStats.Err != nil {
5✔
917
                return result, queryStats.Err
×
918
        }
×
919
        return execResult(sqlResult, &queryStats)
5✔
920
}
921

922
// GetSQL returns a copy of the dialect, query, args, params and rowmapper that
923
// make up the CompiledExec.
924
func (compiledExec *CompiledExec) GetSQL() (dialect string, query string, args []any, params map[string][]int) {
1✔
925
        dialect = compiledExec.dialect
1✔
926
        query = compiledExec.query
1✔
927
        args = make([]any, len(compiledExec.args))
1✔
928
        params = make(map[string][]int)
1✔
929
        copy(args, compiledExec.args)
1✔
930
        for name, indexes := range compiledExec.params {
5✔
931
                indexes2 := make([]int, len(indexes))
4✔
932
                copy(indexes2, indexes)
4✔
933
                params[name] = indexes2
4✔
934
        }
4✔
935
        return dialect, query, args, params
1✔
936
}
937

938
// Prepare creates a PreparedExec from a CompiledExec by preparing it on the
939
// given DB.
940
func (compiledExec *CompiledExec) Prepare(db DB) (*PreparedExec, error) {
×
941
        return compiledExec.PrepareContext(context.Background(), db)
×
942
}
×
943

944
// PrepareContext is like Prepare but additionally requires a context.Context.
945
func (compiledExec *CompiledExec) PrepareContext(ctx context.Context, db DB) (*PreparedExec, error) {
1✔
946
        var err error
1✔
947
        preparedExec := &PreparedExec{
1✔
948
                compiledExec: NewCompiledExec(compiledExec.GetSQL()),
1✔
949
        }
1✔
950
        preparedExec.stmt, err = db.PrepareContext(ctx, compiledExec.query)
1✔
951
        if err != nil {
1✔
952
                return nil, err
×
953
        }
×
954
        preparedExec.logger, _ = db.(SqLogger)
1✔
955
        if preparedExec.logger == nil {
1✔
956
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
957
                if logQuery != nil {
×
958
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
959
                        preparedExec.logger = &sqLogStruct{
×
960
                                logSettings: logSettings,
×
961
                                logQuery:    logQuery,
×
962
                        }
×
963
                }
×
964
        }
965
        return preparedExec, nil
1✔
966
}
967

968
// PrepareExec is the result of preparing a CompiledExec on a DB.
969
type PreparedExec struct {
970
        compiledExec *CompiledExec
971
        stmt         *sql.Stmt
972
        logger       SqLogger
973
}
974

975
// PrepareExec returns a new PreparedExec.
976
func PrepareExec(db DB, q Query) (*PreparedExec, error) {
1✔
977
        return PrepareExecContext(context.Background(), db, q)
1✔
978
}
1✔
979

980
// PrepareExecContext is like PrepareExec but additionally requires a
981
// context.Context.
982
func PrepareExecContext(ctx context.Context, db DB, q Query) (*PreparedExec, error) {
1✔
983
        compiledExec, err := CompileExecContext(ctx, q)
1✔
984
        if err != nil {
1✔
985
                return nil, err
×
986
        }
×
987
        return compiledExec.PrepareContext(ctx, db)
1✔
988
}
989

990
// Close closes the PreparedExec.
991
func (preparedExec *PreparedExec) Close() error {
×
992
        if preparedExec.stmt == nil {
×
993
                return nil
×
994
        }
×
995
        return preparedExec.stmt.Close()
×
996
}
997

998
// Exec executes the PreparedExec with the given params.
999
func (preparedExec *PreparedExec) Exec(params Params) (Result, error) {
5✔
1000
        return preparedExec.exec(context.Background(), params, 1)
5✔
1001
}
5✔
1002

1003
// ExecContext is like Exec but additionally requires a context.Context.
1004
func (preparedExec *PreparedExec) ExecContext(ctx context.Context, params Params) (Result, error) {
×
1005
        return preparedExec.exec(ctx, params, 1)
×
1006
}
×
1007

1008
func (preparedExec *PreparedExec) exec(ctx context.Context, params Params, skip int) (result Result, err error) {
5✔
1009
        queryStats := QueryStats{
5✔
1010
                Dialect: preparedExec.compiledExec.dialect,
5✔
1011
                Query:   preparedExec.compiledExec.query,
5✔
1012
                Args:    preparedExec.compiledExec.args,
5✔
1013
                Params:  preparedExec.compiledExec.params,
5✔
1014
        }
5✔
1015

5✔
1016
        // Setup logger.
5✔
1017
        var logSettings LogSettings
5✔
1018
        if preparedExec.logger != nil {
10✔
1019
                preparedExec.logger.SqLogSettings(ctx, &logSettings)
5✔
1020
                if logSettings.IncludeCaller {
10✔
1021
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
5✔
1022
                }
5✔
1023
                defer func() {
10✔
1024
                        if logSettings.LogAsynchronously {
5✔
1025
                                go preparedExec.logger.SqLogQuery(ctx, queryStats)
×
1026
                        } else {
5✔
1027
                                preparedExec.logger.SqLogQuery(ctx, queryStats)
5✔
1028
                        }
5✔
1029
                }()
1030
        }
1031

1032
        // Substitute params.
1033
        queryStats.Args, err = substituteParams(queryStats.Dialect, queryStats.Args, queryStats.Params, params)
5✔
1034
        if err != nil {
5✔
1035
                return result, err
×
1036
        }
×
1037

1038
        // Run query.
1039
        if logSettings.IncludeTime {
10✔
1040
                queryStats.StartedAt = time.Now()
5✔
1041
        }
5✔
1042
        var sqlResult sql.Result
5✔
1043
        sqlResult, queryStats.Err = preparedExec.stmt.ExecContext(ctx, queryStats.Args...)
5✔
1044
        if logSettings.IncludeTime {
10✔
1045
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
5✔
1046
        }
5✔
1047
        if queryStats.Err != nil {
5✔
1048
                return result, queryStats.Err
×
1049
        }
×
1050
        return execResult(sqlResult, &queryStats)
5✔
1051
}
1052

1053
func getFieldNames(ctx context.Context, row *Row) []string {
10✔
1054
        if len(row.fields) == 0 {
13✔
1055
                columns, _ := row.sqlRows.Columns()
3✔
1056
                return columns
3✔
1057
        }
3✔
1058
        buf := bufpool.Get().(*bytes.Buffer)
7✔
1059
        buf.Reset()
7✔
1060
        defer bufpool.Put(buf)
7✔
1061
        var args []any
7✔
1062
        fieldNames := make([]string, 0, len(row.fields))
7✔
1063
        for _, field := range row.fields {
91✔
1064
                if alias := getAlias(field); alias != "" {
84✔
1065
                        fieldNames = append(fieldNames, alias)
×
1066
                        continue
×
1067
                }
1068
                buf.Reset()
84✔
1069
                args = args[:0]
84✔
1070
                err := field.WriteSQL(ctx, row.dialect, buf, &args, nil)
84✔
1071
                if err != nil {
84✔
1072
                        fieldNames = append(fieldNames, "%!(error="+err.Error()+")")
×
1073
                        continue
×
1074
                }
1075
                fieldName, err := Sprintf(row.dialect, buf.String(), args)
84✔
1076
                if err != nil {
84✔
1077
                        fieldNames = append(fieldNames, "%!(error="+err.Error()+")")
×
1078
                        continue
×
1079
                }
1080
                fieldNames = append(fieldNames, fieldName)
84✔
1081
        }
1082
        return fieldNames
7✔
1083
}
1084

1085
func getFieldMappings(dialect string, fields []Field, scanDest []any) string {
2✔
1086
        var buf bytes.Buffer
2✔
1087
        var args []any
2✔
1088
        var b strings.Builder
2✔
1089
        for i, field := range fields {
5✔
1090
                b.WriteString(fmt.Sprintf("\n %02d. ", i+1))
3✔
1091
                buf.Reset()
3✔
1092
                args = args[:0]
3✔
1093
                err := field.WriteSQL(context.Background(), dialect, &buf, &args, nil)
3✔
1094
                if err != nil {
3✔
1095
                        buf.WriteString("%!(error=" + err.Error() + ")")
×
1096
                        continue
×
1097
                }
1098
                fieldName, err := Sprintf(dialect, buf.String(), args)
3✔
1099
                if err != nil {
3✔
1100
                        b.WriteString("%!(error=" + err.Error() + ")")
×
1101
                        continue
×
1102
                }
1103
                b.WriteString(fieldName + " => " + reflect.TypeOf(scanDest[i]).String())
3✔
1104
        }
1105
        return b.String()
2✔
1106
}
1107

1108
// TODO: inline cursorResult, cursorResults and execResult.
1109

1110
func cursorResult[T any](cursor *Cursor[T]) (result T, err error) {
22✔
1111
        for cursor.Next() {
44✔
1112
                result, err = cursor.Result()
22✔
1113
                if err != nil {
22✔
1114
                        return result, err
×
1115
                }
×
1116
                break
22✔
1117
        }
1118
        if cursor.RowCount() == 0 {
22✔
1119
                return result, sql.ErrNoRows
×
1120
        }
×
1121
        return result, cursor.Close()
22✔
1122
}
1123

1124
func cursorResults[T any](cursor *Cursor[T]) (results []T, err error) {
22✔
1125
        var result T
22✔
1126
        for cursor.Next() {
100✔
1127
                result, err = cursor.Result()
78✔
1128
                if err != nil {
78✔
1129
                        return results, err
×
1130
                }
×
1131
                results = append(results, result)
78✔
1132
        }
1133
        return results, cursor.Close()
22✔
1134
}
1135

1136
func execResult(sqlResult sql.Result, queryStats *QueryStats) (Result, error) {
19✔
1137
        var err error
19✔
1138
        var result Result
19✔
1139
        if queryStats.Dialect == DialectSQLite || queryStats.Dialect == DialectMySQL {
34✔
1140
                result.LastInsertId, err = sqlResult.LastInsertId()
15✔
1141
                if err != nil {
15✔
1142
                        return result, err
×
1143
                }
×
1144
                queryStats.LastInsertId.Valid = true
15✔
1145
                queryStats.LastInsertId.Int64 = result.LastInsertId
15✔
1146
        }
1147
        result.RowsAffected, err = sqlResult.RowsAffected()
19✔
1148
        if err != nil {
19✔
1149
                return result, err
×
1150
        }
×
1151
        queryStats.RowsAffected.Valid = true
19✔
1152
        queryStats.RowsAffected.Int64 = result.RowsAffected
19✔
1153
        return result, nil
19✔
1154
}
1155

1156
// FetchExists returns a boolean indicating if running the given Query on the
1157
// given DB returned any results.
1158
func FetchExists(db DB, query Query) (exists bool, err error) {
4✔
1159
        return fetchExists(context.Background(), db, query, 1)
4✔
1160
}
4✔
1161

1162
// FetchExistsContext is like FetchExists but additionally requires a
1163
// context.Context.
1164
func FetchExistsContext(ctx context.Context, db DB, query Query) (exists bool, err error) {
×
1165
        return fetchExists(ctx, db, query, 1)
×
1166
}
×
1167

1168
func fetchExists(ctx context.Context, db DB, query Query, skip int) (exists bool, err error) {
4✔
1169
        dialect := query.GetDialect()
4✔
1170
        if dialect == "" {
4✔
1171
                defaultDialect := DefaultDialect.Load()
×
1172
                if defaultDialect != nil {
×
1173
                        dialect = *defaultDialect
×
1174
                }
×
1175
        }
1176
        queryStats := QueryStats{
4✔
1177
                Dialect: dialect,
4✔
1178
                Params:  make(map[string][]int),
4✔
1179
                Exists:  sql.NullBool{Valid: true},
4✔
1180
        }
4✔
1181

4✔
1182
        // Build query.
4✔
1183
        buf := bufpool.Get().(*bytes.Buffer)
4✔
1184
        buf.Reset()
4✔
1185
        defer bufpool.Put(buf)
4✔
1186
        if dialect == DialectSQLServer {
5✔
1187
                query = Queryf("SELECT CASE WHEN EXISTS ({}) THEN 1 ELSE 0 END", query)
1✔
1188
        } else {
4✔
1189
                query = Queryf("SELECT EXISTS ({})", query)
3✔
1190
        }
3✔
1191
        err = query.WriteSQL(ctx, dialect, buf, &queryStats.Args, queryStats.Params)
4✔
1192
        queryStats.Query = buf.String()
4✔
1193
        if err != nil {
4✔
1194
                return false, err
×
1195
        }
×
1196

1197
        // Setup logger.
1198
        var logSettings LogSettings
4✔
1199
        logger, _ := db.(SqLogger)
4✔
1200
        if logger == nil {
4✔
1201
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
1202
                if logQuery != nil {
×
1203
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
1204
                        logger = &sqLogStruct{
×
1205
                                logSettings: logSettings,
×
1206
                                logQuery:    logQuery,
×
1207
                        }
×
1208
                }
×
1209
        }
1210
        if logger != nil {
8✔
1211
                logger.SqLogSettings(ctx, &logSettings)
4✔
1212
                if logSettings.IncludeCaller {
8✔
1213
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
4✔
1214
                }
4✔
1215
                defer func() {
8✔
1216
                        if logSettings.LogAsynchronously {
4✔
1217
                                go logger.SqLogQuery(ctx, queryStats)
×
1218
                        } else {
4✔
1219
                                logger.SqLogQuery(ctx, queryStats)
4✔
1220
                        }
4✔
1221
                }()
1222
        }
1223

1224
        // Run query.
1225
        if logSettings.IncludeTime {
8✔
1226
                queryStats.StartedAt = time.Now()
4✔
1227
        }
4✔
1228
        var sqlRows *sql.Rows
4✔
1229
        sqlRows, queryStats.Err = db.QueryContext(ctx, queryStats.Query, queryStats.Args...)
4✔
1230
        if logSettings.IncludeTime {
8✔
1231
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
4✔
1232
        }
4✔
1233
        if queryStats.Err != nil {
4✔
1234
                return false, queryStats.Err
×
1235
        }
×
1236

1237
        for sqlRows.Next() {
8✔
1238
                err = sqlRows.Scan(&exists)
4✔
1239
                if err != nil {
4✔
1240
                        return false, err
×
1241
                }
×
1242
                break
4✔
1243
        }
1244
        queryStats.Exists.Bool = exists
4✔
1245

4✔
1246
        if err := sqlRows.Close(); err != nil {
4✔
1247
                return exists, err
×
1248
        }
×
1249
        if err := sqlRows.Err(); err != nil {
4✔
1250
                return exists, err
×
1251
        }
×
1252
        return exists, nil
4✔
1253
}
1254

1255
// substituteParams will return a new args slice by substituting values from
1256
// the given paramValues. The input args slice is untouched.
1257
func substituteParams(dialect string, args []any, paramIndexes map[string][]int, paramValues map[string]any) ([]any, error) {
21✔
1258
        if len(paramValues) == 0 {
26✔
1259
                return args, nil
5✔
1260
        }
5✔
1261
        newArgs := make([]any, len(args))
16✔
1262
        copy(newArgs, args)
16✔
1263
        var err error
16✔
1264
        for name, value := range paramValues {
65✔
1265
                indexes := paramIndexes[name]
49✔
1266
                for _, index := range indexes {
98✔
1267
                        switch arg := newArgs[index].(type) {
49✔
1268
                        case sql.NamedArg:
46✔
1269
                                arg.Value, err = preprocessValue(dialect, value)
46✔
1270
                                if err != nil {
46✔
1271
                                        return nil, err
×
1272
                                }
×
1273
                                newArgs[index] = arg
46✔
1274
                        default:
3✔
1275
                                value, err = preprocessValue(dialect, value)
3✔
1276
                                if err != nil {
3✔
1277
                                        return nil, err
×
1278
                                }
×
1279
                                newArgs[index] = value
3✔
1280
                        }
1281
                }
1282
        }
1283
        return newArgs, nil
16✔
1284
}
1285

1286
// caller reports the file, line and fully-qualified function name of a stack
// frame. A skip of 0 identifies the function that called caller; each
// increment moves one frame further up the stack. When no such frame exists,
// file and function are empty and line is 0.
func caller(skip int) (file string, line int, function string) {
	var pc uintptr
	pc, file, line, _ = runtime.Caller(skip + 1)
	if f := runtime.FuncForPC(pc); f != nil {
		function = f.Name()
	}
	return file, line, function
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc