• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bokwoon95 / sq / 5783018144

pending completion
5783018144

push

github

bokwoon95
Add option to set default logger and dialect for queries.

1. Add sq.DefaultDialect (an atomic.Pointer[string]) to set the default
   dialect to be used for query building when the passed in dialect is
   empty. Fixes #7.
2. Add SetDefaultLogQuery and SetDefaultLogSettings to configure the
   default logging function (and LogSettings) to be used for all queries
   if a logger is not explicitly passed in.
3. Don't check for ctx.Done() inside the logger; doing so caused
   certain queries not to be logged when LogSettings.LogAsynchronously
   was set to true and the request completed before the query could be
   logged.

114 of 114 new or added lines in 2 files covered. (100.0%)

5401 of 6347 relevant lines covered (85.1%)

48.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.73
/fetch_exec.go
1
package sq
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "fmt"
8
        "reflect"
9
        "runtime"
10
        "strconv"
11
        "strings"
12
        "sync/atomic"
13
        "time"
14
)
15

16
// errMixedCalls is returned when a rowmapper combines the raw-SQL row methods
// (row.Values, row.Columns, row.ColumnTypes) with the other row methods.
var errMixedCalls = fmt.Errorf("rowmapper cannot mix calls to row.Values()/row.Columns()/row.ColumnTypes() with the other row methods")

// errNoFieldsAccessed is returned when the query accepted fetchable fields
// but the rowmapper did not access any, so nothing can be selected.
var errNoFieldsAccessed = fmt.Errorf("rowmapper did not access any fields, unable to determine fields to insert into query")

// errForbiddenCalls is returned when the query did not accept fetchable
// fields but the rowmapper accessed individual fields anyway.
var errForbiddenCalls = fmt.Errorf("rowmapper can only contain calls to row.Values()/row.Columns()/row.ColumnTypes() because query's SELECT clause is not dynamic")

// DefaultDialect holds the dialect used by queries that do not specify one
// themselves. It holds nil (no default) until a value is stored.
var DefaultDialect atomic.Pointer[string]
24

25
// A Cursor represents a database cursor.
26
type Cursor[T any] struct {
27
        ctx           context.Context
28
        row           *Row
29
        rowmapper     func(*Row) T
30
        queryStats    QueryStats
31
        logSettings   LogSettings
32
        logger        SqLogger
33
        logged        int32
34
        fieldNames    []string
35
        resultsBuffer *bytes.Buffer
36
}
37

38
// FetchCursor returns a new cursor.
39
func FetchCursor[T any](db DB, query Query, rowmapper func(*Row) T) (*Cursor[T], error) {
×
40
        return fetchCursor(context.Background(), db, query, rowmapper, 1)
×
41
}
×
42

43
// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
44
func FetchCursorContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) (*Cursor[T], error) {
×
45
        return fetchCursor(ctx, db, query, rowmapper, 1)
×
46
}
×
47

48
func fetchCursor[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T, skip int) (cursor *Cursor[T], err error) {
40✔
49
        if db == nil {
40✔
50
                return nil, fmt.Errorf("db is nil")
×
51
        }
×
52
        if query == nil {
40✔
53
                return nil, fmt.Errorf("query is nil")
×
54
        }
×
55
        if rowmapper == nil {
40✔
56
                return nil, fmt.Errorf("rowmapper is nil")
×
57
        }
×
58
        dialect := query.GetDialect()
40✔
59
        if dialect == "" {
40✔
60
                defaultDialect := DefaultDialect.Load()
×
61
                if defaultDialect != nil {
×
62
                        dialect = *defaultDialect
×
63
                }
×
64
        }
65
        cursor = &Cursor[T]{
40✔
66
                ctx:       ctx,
40✔
67
                rowmapper: rowmapper,
40✔
68
                row: &Row{
40✔
69
                        dialect: dialect,
40✔
70
                },
40✔
71
                queryStats: QueryStats{
40✔
72
                        Dialect:  dialect,
40✔
73
                        RowCount: sql.NullInt64{Valid: true},
40✔
74
                        Params:   make(map[string][]int),
40✔
75
                },
40✔
76
        }
40✔
77

40✔
78
        // Call the rowmapper to populate row.fields and row.scanDest.
40✔
79
        defer mapperFunctionPanicked(&err)
40✔
80
        _ = cursor.rowmapper(cursor.row)
40✔
81
        var ok bool
40✔
82
        if cursor.row.rawSQLMode && len(cursor.row.fields) > 0 {
44✔
83
                return nil, errMixedCalls
4✔
84
        }
4✔
85

86
        // Insert the fields into the query.
87
        query, ok = query.SetFetchableFields(cursor.row.fields)
36✔
88
        if ok && len(cursor.row.fields) == 0 {
40✔
89
                return nil, errNoFieldsAccessed
4✔
90
        }
4✔
91
        if !ok && len(cursor.row.fields) > 0 {
36✔
92
                return nil, errForbiddenCalls
4✔
93
        }
4✔
94

95
        // Build query.
96
        buf := bufpool.Get().(*bytes.Buffer)
28✔
97
        buf.Reset()
28✔
98
        defer bufpool.Put(buf)
28✔
99
        err = query.WriteSQL(ctx, dialect, buf, &cursor.queryStats.Args, cursor.queryStats.Params)
28✔
100
        cursor.queryStats.Query = buf.String()
28✔
101
        if err != nil {
28✔
102
                return nil, err
×
103
        }
×
104

105
        // Setup logger.
106
        cursor.logger, _ = db.(SqLogger)
28✔
107
        if cursor.logger == nil {
48✔
108
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
20✔
109
                if logQuery != nil {
20✔
110
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
111
                        cursor.logger = &sqLogStruct{
×
112
                                logSettings: logSettings,
×
113
                                logQuery:    logQuery,
×
114
                        }
×
115
                }
×
116
        }
117
        if cursor.logger != nil {
36✔
118
                cursor.logger.SqLogSettings(ctx, &cursor.logSettings)
8✔
119
                if cursor.logSettings.IncludeCaller {
16✔
120
                        cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1)
8✔
121
                }
8✔
122
        }
123

124
        // Run query.
125
        if cursor.logSettings.IncludeTime {
36✔
126
                cursor.queryStats.StartedAt = time.Now()
8✔
127
        }
8✔
128
        cursor.row.sqlRows, cursor.queryStats.Err = db.QueryContext(ctx, cursor.queryStats.Query, cursor.queryStats.Args...)
28✔
129
        if cursor.logSettings.IncludeTime {
36✔
130
                cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt)
8✔
131
        }
8✔
132
        if cursor.queryStats.Err != nil {
28✔
133
                cursor.log()
×
134
                return nil, cursor.queryStats.Err
×
135
        }
×
136

137
        // Allocate the resultsBuffer.
138
        if cursor.logSettings.IncludeResults > 0 {
34✔
139
                cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
6✔
140
                cursor.resultsBuffer.Reset()
6✔
141
        }
6✔
142
        return cursor, nil
28✔
143
}
144

145
// Next advances the cursor to the next result.
146
func (cursor *Cursor[T]) Next() bool {
102✔
147
        hasNext := cursor.row.sqlRows.Next()
102✔
148
        if hasNext {
186✔
149
                cursor.queryStats.RowCount.Int64++
84✔
150
        } else {
102✔
151
                cursor.log()
18✔
152
        }
18✔
153
        return hasNext
102✔
154
}
155

156
// RowCount returns the current row number so far.
157
func (cursor *Cursor[T]) RowCount() int64 { return cursor.queryStats.RowCount.Int64 }
18✔
158

159
// Result returns the cursor result.
160
func (cursor *Cursor[T]) Result() (result T, err error) {
84✔
161
        if !cursor.row.rawSQLMode {
134✔
162
                err = cursor.row.sqlRows.Scan(cursor.row.scanDest...)
50✔
163
                if err != nil {
50✔
164
                        cursor.log()
×
165
                        fieldMappings := getFieldMappings(cursor.queryStats.Dialect, cursor.row.fields, cursor.row.scanDest)
×
166
                        return result, fmt.Errorf("please check if your mapper function is correct:%s\n%w", fieldMappings, err)
×
167
                }
×
168
        }
169
        // If results should be logged, write the row into the resultsBuffer.
170
        if cursor.resultsBuffer != nil && cursor.queryStats.RowCount.Int64 <= int64(cursor.logSettings.IncludeResults) {
126✔
171
                if len(cursor.fieldNames) == 0 {
52✔
172
                        cursor.fieldNames = getFieldNames(cursor.ctx, cursor.row)
10✔
173
                }
10✔
174
                cursor.resultsBuffer.WriteString("\n----[ Row " + strconv.FormatInt(cursor.queryStats.RowCount.Int64, 10) + " ]----")
42✔
175
                for i := range cursor.row.scanDest {
366✔
176
                        cursor.resultsBuffer.WriteString("\n")
324✔
177
                        if i < len(cursor.fieldNames) {
648✔
178
                                cursor.resultsBuffer.WriteString(cursor.fieldNames[i])
324✔
179
                        }
324✔
180
                        cursor.resultsBuffer.WriteString(": ")
324✔
181
                        scanDest := cursor.row.scanDest[i]
324✔
182
                        rhs, err := Sprint(cursor.queryStats.Dialect, scanDest)
324✔
183
                        if err != nil {
324✔
184
                                cursor.resultsBuffer.WriteString("%!(error=" + err.Error() + ")")
×
185
                                continue
×
186
                        }
187
                        cursor.resultsBuffer.WriteString(rhs)
324✔
188
                }
189
        }
190
        cursor.row.index = 0
84✔
191
        defer mapperFunctionPanicked(&err)
84✔
192
        result = cursor.rowmapper(cursor.row)
84✔
193
        return result, nil
84✔
194
}
195

196
func (cursor *Cursor[T]) log() {
90✔
197
        if !atomic.CompareAndSwapInt32(&cursor.logged, 0, 1) {
144✔
198
                return
54✔
199
        }
54✔
200
        if cursor.resultsBuffer != nil {
46✔
201
                cursor.queryStats.Results = cursor.resultsBuffer.String()
10✔
202
                bufpool.Put(cursor.resultsBuffer)
10✔
203
        }
10✔
204
        if cursor.logger == nil {
56✔
205
                return
20✔
206
        }
20✔
207
        if cursor.logSettings.LogAsynchronously {
16✔
208
                go cursor.logger.SqLogQuery(cursor.ctx, cursor.queryStats)
×
209
        } else {
16✔
210
                cursor.logger.SqLogQuery(cursor.ctx, cursor.queryStats)
16✔
211
        }
16✔
212
}
213

214
// Close closes the cursor.
215
func (cursor *Cursor[T]) Close() error {
72✔
216
        cursor.log()
72✔
217
        if err := cursor.row.sqlRows.Close(); err != nil {
72✔
218
                return err
×
219
        }
×
220
        if err := cursor.row.sqlRows.Err(); err != nil {
72✔
221
                return err
×
222
        }
×
223
        return nil
72✔
224
}
225

226
// FetchOne returns the first result from running the given Query on the given
227
// DB.
228
func FetchOne[T any](db DB, query Query, rowmapper func(*Row) T) (T, error) {
14✔
229
        cursor, err := fetchCursor(context.Background(), db, query, rowmapper, 1)
14✔
230
        if err != nil {
14✔
231
                return *new(T), err
×
232
        }
×
233
        defer cursor.Close()
14✔
234
        return cursorResult(cursor)
14✔
235
}
236

237
// FetchOneContext is like FetchOne but additionally requires a context.Context.
238
func FetchOneContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) (T, error) {
×
239
        cursor, err := fetchCursor(ctx, db, query, rowmapper, 1)
×
240
        if err != nil {
×
241
                return *new(T), err
×
242
        }
×
243
        defer cursor.Close()
×
244
        return cursorResult(cursor)
×
245
}
246

247
// FetchAll returns all results from running the given Query on the given DB.
248
func FetchAll[T any](db DB, query Query, rowmapper func(*Row) T) ([]T, error) {
26✔
249
        cursor, err := fetchCursor(context.Background(), db, query, rowmapper, 1)
26✔
250
        if err != nil {
38✔
251
                return nil, err
12✔
252
        }
12✔
253
        defer cursor.Close()
14✔
254
        return cursorResults(cursor)
14✔
255
}
256

257
// FetchAllContext is like FetchAll but additionally requires a context.Context.
258
func FetchAllContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) ([]T, error) {
×
259
        cursor, err := fetchCursor(ctx, db, query, rowmapper, 1)
×
260
        if err != nil {
×
261
                return nil, err
×
262
        }
×
263
        defer cursor.Close()
×
264
        return cursorResults(cursor)
×
265
}
266

267
// CompiledFetch is the result of compiling a Query down into a query string
268
// and args slice. A CompiledFetch can be safely executed in parallel.
269
type CompiledFetch[T any] struct {
270
        dialect   string
271
        query     string
272
        args      []any
273
        params    map[string][]int
274
        rowmapper func(*Row) T
275
}
276

277
// NewCompiledFetch returns a new CompiledFetch.
278
func NewCompiledFetch[T any](dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) *CompiledFetch[T] {
4✔
279
        return &CompiledFetch[T]{
4✔
280
                dialect:   dialect,
4✔
281
                query:     query,
4✔
282
                args:      args,
4✔
283
                params:    params,
4✔
284
                rowmapper: rowmapper,
4✔
285
        }
4✔
286
}
4✔
287

288
// CompileFetch returns a new CompileFetch.
289
func CompileFetch[T any](q Query, rowmapper func(*Row) T) (*CompiledFetch[T], error) {
4✔
290
        return CompileFetchContext(context.Background(), q, rowmapper)
4✔
291
}
4✔
292

293
// CompileFetchContext is like CompileFetch but accepts a context.Context.
294
func CompileFetchContext[T any](ctx context.Context, query Query, rowmapper func(*Row) T) (compiledFetch *CompiledFetch[T], err error) {
8✔
295
        if query == nil {
8✔
296
                return nil, fmt.Errorf("query is nil")
×
297
        }
×
298
        if rowmapper == nil {
8✔
299
                return nil, fmt.Errorf("rowmapper is nil")
×
300
        }
×
301
        dialect := query.GetDialect()
8✔
302
        if dialect == "" {
8✔
303
                defaultDialect := DefaultDialect.Load()
×
304
                if defaultDialect != nil {
×
305
                        dialect = *defaultDialect
×
306
                }
×
307
        }
308
        compiledFetch = &CompiledFetch[T]{
8✔
309
                dialect:   dialect,
8✔
310
                params:    make(map[string][]int),
8✔
311
                rowmapper: rowmapper,
8✔
312
        }
8✔
313
        row := &Row{
8✔
314
                dialect: dialect,
8✔
315
        }
8✔
316

8✔
317
        // Call the rowmapper to populate row.fields.
8✔
318
        defer mapperFunctionPanicked(&err)
8✔
319
        _ = rowmapper(row)
8✔
320
        var ok bool
8✔
321
        if row.rawSQLMode && len(row.fields) > 0 {
8✔
322
                return nil, errMixedCalls
×
323
        }
×
324

325
        // Insert the fields into the query.
326
        query, ok = query.SetFetchableFields(row.fields)
8✔
327
        if ok && len(row.fields) == 0 {
8✔
328
                return nil, errNoFieldsAccessed
×
329
        }
×
330
        if !ok && len(row.fields) > 0 {
8✔
331
                return nil, errForbiddenCalls
×
332
        }
×
333

334
        // Build query.
335
        buf := bufpool.Get().(*bytes.Buffer)
8✔
336
        buf.Reset()
8✔
337
        defer bufpool.Put(buf)
8✔
338
        err = query.WriteSQL(ctx, dialect, buf, &compiledFetch.args, compiledFetch.params)
8✔
339
        compiledFetch.query = buf.String()
8✔
340
        if err != nil {
8✔
341
                return nil, err
×
342
        }
×
343
        return compiledFetch, nil
8✔
344
}
345

346
// FetchCursor returns a new cursor.
347
func (compiledFetch *CompiledFetch[T]) FetchCursor(db DB, params Params) (*Cursor[T], error) {
×
348
        return compiledFetch.fetchCursor(context.Background(), db, params, 1)
×
349
}
×
350

351
// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
352
func (compiledFetch *CompiledFetch[T]) FetchCursorContext(ctx context.Context, db DB, params Params) (*Cursor[T], error) {
×
353
        return compiledFetch.fetchCursor(ctx, db, params, 1)
×
354
}
×
355

356
func (compiledFetch *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, params Params, skip int) (cursor *Cursor[T], err error) {
4✔
357
        if db == nil {
4✔
358
                return nil, fmt.Errorf("db is nil")
×
359
        }
×
360
        cursor = &Cursor[T]{
4✔
361
                ctx:       ctx,
4✔
362
                rowmapper: compiledFetch.rowmapper,
4✔
363
                row: &Row{
4✔
364
                        dialect: compiledFetch.dialect,
4✔
365
                },
4✔
366
                queryStats: QueryStats{
4✔
367
                        Dialect: compiledFetch.dialect,
4✔
368
                        Query:   compiledFetch.query,
4✔
369
                        Args:    compiledFetch.args,
4✔
370
                        Params:  compiledFetch.params,
4✔
371
                },
4✔
372
        }
4✔
373

4✔
374
        // Call the rowmapper to populate row.scanDest.
4✔
375
        defer mapperFunctionPanicked(&err)
4✔
376
        _ = cursor.rowmapper(cursor.row)
4✔
377
        if err != nil {
4✔
378
                return nil, err
×
379
        }
×
380

381
        // Substitute params.
382
        cursor.queryStats.Args, err = substituteParams(cursor.queryStats.Dialect, cursor.queryStats.Args, cursor.queryStats.Params, params)
4✔
383
        if err != nil {
4✔
384
                return nil, err
×
385
        }
×
386

387
        // Setup logger.
388
        cursor.queryStats.RowCount.Valid = true
4✔
389
        cursor.logger, _ = db.(SqLogger)
4✔
390
        if cursor.logger == nil {
4✔
391
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
392
                if logQuery != nil {
×
393
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
394
                        cursor.logger = &sqLogStruct{
×
395
                                logSettings: logSettings,
×
396
                                logQuery:    logQuery,
×
397
                        }
×
398
                }
×
399
        }
400
        if cursor.logger != nil {
8✔
401
                cursor.logger.SqLogSettings(ctx, &cursor.logSettings)
4✔
402
                if cursor.logSettings.IncludeCaller {
8✔
403
                        cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1)
4✔
404
                }
4✔
405
        }
406

407
        // Run query.
408
        if cursor.logSettings.IncludeTime {
8✔
409
                cursor.queryStats.StartedAt = time.Now()
4✔
410
        }
4✔
411
        cursor.row.sqlRows, cursor.queryStats.Err = db.QueryContext(ctx, cursor.queryStats.Query, cursor.queryStats.Args...)
4✔
412
        if cursor.logSettings.IncludeTime {
8✔
413
                cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt)
4✔
414
        }
4✔
415
        if cursor.queryStats.Err != nil {
4✔
416
                return nil, cursor.queryStats.Err
×
417
        }
×
418

419
        // Allocate the resultsBuffer.
420
        if cursor.logSettings.IncludeResults > 0 {
6✔
421
                cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
2✔
422
                cursor.resultsBuffer.Reset()
2✔
423
        }
2✔
424
        return cursor, nil
4✔
425
}
426

427
// FetchOne returns the first result from running the CompiledFetch on the
428
// given DB with the give params.
429
func (compiledFetch *CompiledFetch[T]) FetchOne(db DB, params Params) (T, error) {
2✔
430
        cursor, err := compiledFetch.fetchCursor(context.Background(), db, params, 1)
2✔
431
        if err != nil {
2✔
432
                return *new(T), err
×
433
        }
×
434
        defer cursor.Close()
2✔
435
        return cursorResult(cursor)
2✔
436
}
437

438
// FetchOneContext is like FetchOne but additionally requires a context.Context.
439
func (compiledFetch *CompiledFetch[T]) FetchOneContext(ctx context.Context, db DB, params Params) (T, error) {
×
440
        cursor, err := compiledFetch.fetchCursor(ctx, db, params, 1)
×
441
        if err != nil {
×
442
                return *new(T), err
×
443
        }
×
444
        defer cursor.Close()
×
445
        return cursorResult(cursor)
×
446
}
447

448
// FetchAll returns all the results from running the CompiledFetch on the given
449
// DB with the give params.
450
func (compiledFetch *CompiledFetch[T]) FetchAll(db DB, params Params) ([]T, error) {
2✔
451
        cursor, err := compiledFetch.fetchCursor(context.Background(), db, params, 1)
2✔
452
        if err != nil {
2✔
453
                return nil, err
×
454
        }
×
455
        defer cursor.Close()
2✔
456
        return cursorResults(cursor)
2✔
457
}
458

459
// FetchAllContext is like FetchAll but additionally requires a context.Context.
460
func (compiledFetch *CompiledFetch[T]) FetchAllContext(ctx context.Context, db DB, params Params) ([]T, error) {
×
461
        cursor, err := compiledFetch.fetchCursor(ctx, db, params, 1)
×
462
        if err != nil {
×
463
                return nil, err
×
464
        }
×
465
        defer cursor.Close()
×
466
        return cursorResults(cursor)
×
467
}
468

469
// GetSQL returns a copy of the dialect, query, args, params and rowmapper that
470
// make up the CompiledFetch.
471
func (compiledFetch *CompiledFetch[T]) GetSQL() (dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) {
4✔
472
        dialect = compiledFetch.dialect
4✔
473
        query = compiledFetch.query
4✔
474
        args = make([]any, len(compiledFetch.args))
4✔
475
        params = make(map[string][]int)
4✔
476
        copy(args, compiledFetch.args)
4✔
477
        for name, indexes := range compiledFetch.params {
6✔
478
                indexes2 := make([]int, len(indexes))
2✔
479
                copy(indexes2, indexes)
2✔
480
                params[name] = indexes2
2✔
481
        }
2✔
482
        return dialect, query, args, params, compiledFetch.rowmapper
4✔
483
}
484

485
// Prepare creates a PreparedFetch from a CompiledFetch by preparing it on
486
// the given DB.
487
func (compiledFetch *CompiledFetch[T]) Prepare(db DB) (*PreparedFetch[T], error) {
×
488
        return compiledFetch.PrepareContext(context.Background(), db)
×
489
}
×
490

491
// PrepareContext is like Prepare but additionally requires a context.Context.
492
func (compiledFetch *CompiledFetch[T]) PrepareContext(ctx context.Context, db DB) (*PreparedFetch[T], error) {
4✔
493
        var err error
4✔
494
        preparedFetch := &PreparedFetch[T]{
4✔
495
                compiledFetch: NewCompiledFetch(compiledFetch.GetSQL()),
4✔
496
        }
4✔
497
        if db == nil {
4✔
498
                return nil, fmt.Errorf("db is nil")
×
499
        }
×
500
        preparedFetch.stmt, err = db.PrepareContext(ctx, compiledFetch.query)
4✔
501
        if err != nil {
4✔
502
                return nil, err
×
503
        }
×
504
        preparedFetch.logger, _ = db.(SqLogger)
4✔
505
        if preparedFetch.logger == nil {
4✔
506
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
507
                if logQuery != nil {
×
508
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
509
                        preparedFetch.logger = &sqLogStruct{
×
510
                                logSettings: logSettings,
×
511
                                logQuery:    logQuery,
×
512
                        }
×
513
                }
×
514
        }
515
        return preparedFetch, nil
4✔
516
}
517

518
// PreparedFetch is the result of preparing a CompiledFetch on a DB.
519
type PreparedFetch[T any] struct {
520
        compiledFetch *CompiledFetch[T]
521
        stmt          *sql.Stmt
522
        logger        SqLogger
523
}
524

525
// PrepareFetch returns a new PreparedFetch.
526
func PrepareFetch[T any](db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) {
4✔
527
        return PrepareFetchContext(context.Background(), db, q, rowmapper)
4✔
528
}
4✔
529

530
// PrepareFetchContext is like PrepareFetch but additionally requires a context.Context.
531
func PrepareFetchContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) {
4✔
532
        compiledFetch, err := CompileFetchContext(ctx, q, rowmapper)
4✔
533
        if err != nil {
4✔
534
                return nil, err
×
535
        }
×
536
        return compiledFetch.PrepareContext(ctx, db)
4✔
537
}
538

539
// FetchCursor returns a new cursor.
540
func (preparedFetch PreparedFetch[T]) FetchCursor(params Params) (*Cursor[T], error) {
×
541
        return preparedFetch.fetchCursor(context.Background(), params, 1)
×
542
}
×
543

544
// FetchCursorContext is like FetchCursor but additionally requires a context.Context.
545
func (preparedFetch PreparedFetch[T]) FetchCursorContext(ctx context.Context, params Params) (*Cursor[T], error) {
×
546
        return preparedFetch.fetchCursor(ctx, params, 1)
×
547
}
×
548

549
func (preparedFetch *PreparedFetch[T]) fetchCursor(ctx context.Context, params Params, skip int) (cursor *Cursor[T], err error) {
4✔
550
        cursor = &Cursor[T]{
4✔
551
                ctx:       ctx,
4✔
552
                rowmapper: preparedFetch.compiledFetch.rowmapper,
4✔
553
                row: &Row{
4✔
554
                        dialect: preparedFetch.compiledFetch.dialect,
4✔
555
                },
4✔
556
                queryStats: QueryStats{
4✔
557
                        Dialect:  preparedFetch.compiledFetch.dialect,
4✔
558
                        Query:    preparedFetch.compiledFetch.query,
4✔
559
                        Args:     preparedFetch.compiledFetch.args,
4✔
560
                        Params:   preparedFetch.compiledFetch.params,
4✔
561
                        RowCount: sql.NullInt64{Valid: true},
4✔
562
                },
4✔
563
                logger: preparedFetch.logger,
4✔
564
        }
4✔
565

4✔
566
        // Call the rowmapper to populate row.scanDest.
4✔
567
        defer mapperFunctionPanicked(&err)
4✔
568
        _ = cursor.rowmapper(cursor.row)
4✔
569
        if err != nil {
4✔
570
                return nil, err
×
571
        }
×
572

573
        // Substitute params.
574
        cursor.queryStats.Args, err = substituteParams(cursor.queryStats.Dialect, cursor.queryStats.Args, cursor.queryStats.Params, params)
4✔
575
        if err != nil {
4✔
576
                return nil, err
×
577
        }
×
578

579
        // Setup logger.
580
        if cursor.logger != nil {
8✔
581
                cursor.logger.SqLogSettings(ctx, &cursor.logSettings)
4✔
582
                if cursor.logSettings.IncludeCaller {
8✔
583
                        cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1)
4✔
584
                }
4✔
585
        }
586

587
        // Run query.
588
        if cursor.logSettings.IncludeTime {
8✔
589
                cursor.queryStats.StartedAt = time.Now()
4✔
590
        }
4✔
591
        cursor.row.sqlRows, cursor.queryStats.Err = preparedFetch.stmt.QueryContext(ctx, cursor.queryStats.Args...)
4✔
592
        if cursor.logSettings.IncludeTime {
8✔
593
                cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt)
4✔
594
        }
4✔
595
        if cursor.queryStats.Err != nil {
4✔
596
                return nil, cursor.queryStats.Err
×
597
        }
×
598

599
        // Allocate the resultsBuffer.
600
        if cursor.logSettings.IncludeResults > 0 {
6✔
601
                cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer)
2✔
602
                cursor.resultsBuffer.Reset()
2✔
603
        }
2✔
604
        return cursor, nil
4✔
605
}
606

607
// FetchOne returns the first result from running the PreparedFetch with the
608
// give params.
609
func (preparedFetch *PreparedFetch[T]) FetchOne(params Params) (T, error) {
2✔
610
        cursor, err := preparedFetch.fetchCursor(context.Background(), params, 1)
2✔
611
        if err != nil {
2✔
612
                return *new(T), err
×
613
        }
×
614
        defer cursor.Close()
2✔
615
        return cursorResult(cursor)
2✔
616
}
617

618
// FetchOneContext is like FetchOne but additionally requires a context.Context.
619
func (preparedFetch *PreparedFetch[T]) FetchOneContext(ctx context.Context, params Params) (T, error) {
×
620
        cursor, err := preparedFetch.fetchCursor(ctx, params, 1)
×
621
        if err != nil {
×
622
                return *new(T), err
×
623
        }
×
624
        defer cursor.Close()
×
625
        return cursorResult(cursor)
×
626
}
627

628
// FetchAll returns all the results from running the PreparedFetch with the
629
// give params.
630
func (preparedFetch *PreparedFetch[T]) FetchAll(params Params) ([]T, error) {
2✔
631
        cursor, err := preparedFetch.fetchCursor(context.Background(), params, 1)
2✔
632
        if err != nil {
2✔
633
                return nil, err
×
634
        }
×
635
        defer cursor.Close()
2✔
636
        return cursorResults(cursor)
2✔
637
}
638

639
// FetchAllContext is like FetchAll but additionally requires a context.Context.
640
func (preparedFetch *PreparedFetch[T]) FetchAllContext(ctx context.Context, params Params) ([]T, error) {
×
641
        cursor, err := preparedFetch.fetchCursor(ctx, params, 1)
×
642
        if err != nil {
×
643
                return nil, err
×
644
        }
×
645
        defer cursor.Close()
×
646
        return cursorResults(cursor)
×
647
}
648

649
// GetCompiled returns a copy of the underlying CompiledFetch.
650
func (preparedFetch *PreparedFetch[T]) GetCompiled() *CompiledFetch[T] {
×
651
        return NewCompiledFetch(preparedFetch.compiledFetch.GetSQL())
×
652
}
×
653

654
// Close closes the PreparedFetch.
655
func (preparedFetch *PreparedFetch[T]) Close() error {
×
656
        if preparedFetch.stmt == nil {
×
657
                return nil
×
658
        }
×
659
        return preparedFetch.stmt.Close()
×
660
}
661

662
// Exec executes the given Query on the given DB.
663
func Exec(db DB, query Query) (Result, error) {
9✔
664
        return exec(context.Background(), db, query, 1)
9✔
665
}
9✔
666

667
// ExecContext is like Exec but additionally requires a context.Context.
668
func ExecContext(ctx context.Context, db DB, query Query) (Result, error) {
×
669
        return exec(ctx, db, query, 1)
×
670
}
×
671

672
func exec(ctx context.Context, db DB, query Query, skip int) (result Result, err error) {
9✔
673
        if db == nil {
9✔
674
                return result, fmt.Errorf("db is nil")
×
675
        }
×
676
        if query == nil {
9✔
677
                return result, fmt.Errorf("query is nil")
×
678
        }
×
679
        dialect := query.GetDialect()
9✔
680
        if dialect == "" {
9✔
681
                defaultDialect := DefaultDialect.Load()
×
682
                if defaultDialect != nil {
×
683
                        dialect = *defaultDialect
×
684
                }
×
685
        }
686
        queryStats := QueryStats{
9✔
687
                Dialect: dialect,
9✔
688
                Params:  make(map[string][]int),
9✔
689
        }
9✔
690

9✔
691
        // Build query.
9✔
692
        buf := bufpool.Get().(*bytes.Buffer)
9✔
693
        buf.Reset()
9✔
694
        defer bufpool.Put(buf)
9✔
695
        err = query.WriteSQL(ctx, dialect, buf, &queryStats.Args, queryStats.Params)
9✔
696
        queryStats.Query = buf.String()
9✔
697
        if err != nil {
9✔
698
                return result, err
×
699
        }
×
700

701
        // Setup logger.
702
        var logSettings LogSettings
9✔
703
        logger, _ := db.(SqLogger)
9✔
704
        if logger == nil {
9✔
705
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
706
                if logQuery != nil {
×
707
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
708
                        logger = &sqLogStruct{
×
709
                                logSettings: logSettings,
×
710
                                logQuery:    logQuery,
×
711
                        }
×
712
                }
×
713
        }
714
        if logger != nil {
18✔
715
                logger.SqLogSettings(ctx, &logSettings)
9✔
716
                if logSettings.IncludeCaller {
18✔
717
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
9✔
718
                }
9✔
719
                defer func() {
18✔
720
                        if logSettings.LogAsynchronously {
9✔
721
                                go logger.SqLogQuery(ctx, queryStats)
×
722
                        } else {
9✔
723
                                logger.SqLogQuery(ctx, queryStats)
9✔
724
                        }
9✔
725
                }()
726
        }
727

728
        // Run query.
729
        if logSettings.IncludeTime {
18✔
730
                queryStats.StartedAt = time.Now()
9✔
731
        }
9✔
732
        var sqlResult sql.Result
9✔
733
        sqlResult, queryStats.Err = db.ExecContext(ctx, queryStats.Query, queryStats.Args...)
9✔
734
        if logSettings.IncludeTime {
18✔
735
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
9✔
736
        }
9✔
737
        if queryStats.Err != nil {
9✔
738
                return result, queryStats.Err
×
739
        }
×
740
        return execResult(sqlResult, &queryStats)
9✔
741
}
742

743
// CompiledExec is a Query that has already been rendered to SQL for a single
// dialect: the query string, its args slice, and for every named param the
// indexes in args that it occupies. A CompiledExec can be safely executed in
// parallel.
type CompiledExec struct {
	dialect, query string
	args           []any
	params         map[string][]int
}
751

752
// NewCompiledExec returns a new CompiledExec.
753
func NewCompiledExec(dialect string, query string, args []any, params map[string][]int) *CompiledExec {
1✔
754
        return &CompiledExec{
1✔
755
                dialect: dialect,
1✔
756
                query:   query,
1✔
757
                args:    args,
1✔
758
                params:  params,
1✔
759
        }
1✔
760
}
1✔
761

762
// CompileExec returns a new CompiledExec.
763
func CompileExec(query Query) (*CompiledExec, error) {
1✔
764
        return CompileExecContext(context.Background(), query)
1✔
765
}
1✔
766

767
// CompileExecContext is like CompileExec but additionally requires a context.Context.
768
func CompileExecContext(ctx context.Context, query Query) (*CompiledExec, error) {
2✔
769
        if query == nil {
2✔
770
                return nil, fmt.Errorf("query is nil")
×
771
        }
×
772
        dialect := query.GetDialect()
2✔
773
        if dialect == "" {
2✔
774
                defaultDialect := DefaultDialect.Load()
×
775
                if defaultDialect != nil {
×
776
                        dialect = *defaultDialect
×
777
                }
×
778
        }
779
        compiledExec := &CompiledExec{
2✔
780
                dialect: dialect,
2✔
781
                params:  make(map[string][]int),
2✔
782
        }
2✔
783

2✔
784
        // Build query.
2✔
785
        buf := bufpool.Get().(*bytes.Buffer)
2✔
786
        buf.Reset()
2✔
787
        defer bufpool.Put(buf)
2✔
788
        err := query.WriteSQL(ctx, dialect, buf, &compiledExec.args, compiledExec.params)
2✔
789
        compiledExec.query = buf.String()
2✔
790
        if err != nil {
2✔
791
                return nil, err
×
792
        }
×
793
        return compiledExec, nil
2✔
794
}
795

796
// Exec executes the CompiledExec on the given DB with the given params.
797
func (compiledExec *CompiledExec) Exec(db DB, params Params) (Result, error) {
5✔
798
        return compiledExec.exec(context.Background(), db, params, 1)
5✔
799
}
5✔
800

801
// ExecContext is like Exec but additionally requires a context.Context.
802
func (compiledExec *CompiledExec) ExecContext(ctx context.Context, db DB, params Params) (Result, error) {
×
803
        return compiledExec.exec(ctx, db, params, 1)
×
804
}
×
805

806
func (compiledExec *CompiledExec) exec(ctx context.Context, db DB, params Params, skip int) (result Result, err error) {
5✔
807
        if db == nil {
5✔
808
                return result, fmt.Errorf("db is nil")
×
809
        }
×
810
        queryStats := QueryStats{
5✔
811
                Dialect: compiledExec.dialect,
5✔
812
                Query:   compiledExec.query,
5✔
813
                Args:    compiledExec.args,
5✔
814
                Params:  compiledExec.params,
5✔
815
        }
5✔
816

5✔
817
        // Setup logger.
5✔
818
        var logSettings LogSettings
5✔
819
        logger, _ := db.(SqLogger)
5✔
820
        if logger == nil {
5✔
821
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
822
                if logQuery != nil {
×
823
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
824
                        logger = &sqLogStruct{
×
825
                                logSettings: logSettings,
×
826
                                logQuery:    logQuery,
×
827
                        }
×
828
                }
×
829
        }
830
        if logger != nil {
10✔
831
                logger.SqLogSettings(ctx, &logSettings)
5✔
832
                if logSettings.IncludeCaller {
10✔
833
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
5✔
834
                }
5✔
835
                defer func() {
10✔
836
                        if logSettings.LogAsynchronously {
5✔
837
                                go logger.SqLogQuery(ctx, queryStats)
×
838
                        } else {
5✔
839
                                logger.SqLogQuery(ctx, queryStats)
5✔
840
                        }
5✔
841
                }()
842
        }
843

844
        // Substitute params.
845
        queryStats.Args, err = substituteParams(queryStats.Dialect, queryStats.Args, queryStats.Params, params)
5✔
846
        if err != nil {
5✔
847
                return result, err
×
848
        }
×
849

850
        // Run query.
851
        if logSettings.IncludeTime {
10✔
852
                queryStats.StartedAt = time.Now()
5✔
853
        }
5✔
854
        var sqlResult sql.Result
5✔
855
        sqlResult, queryStats.Err = db.ExecContext(ctx, queryStats.Query, queryStats.Args...)
5✔
856
        if logSettings.IncludeTime {
10✔
857
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
5✔
858
        }
5✔
859
        if queryStats.Err != nil {
5✔
860
                return result, queryStats.Err
×
861
        }
×
862
        return execResult(sqlResult, &queryStats)
5✔
863
}
864

865
// GetSQL returns a copy of the dialect, query, args, params and rowmapper that
866
// make up the CompiledExec.
867
func (compiledExec *CompiledExec) GetSQL() (dialect string, query string, args []any, params map[string][]int) {
1✔
868
        dialect = compiledExec.dialect
1✔
869
        query = compiledExec.query
1✔
870
        args = make([]any, len(compiledExec.args))
1✔
871
        params = make(map[string][]int)
1✔
872
        copy(args, compiledExec.args)
1✔
873
        for name, indexes := range compiledExec.params {
5✔
874
                indexes2 := make([]int, len(indexes))
4✔
875
                copy(indexes2, indexes)
4✔
876
                params[name] = indexes2
4✔
877
        }
4✔
878
        return dialect, query, args, params
1✔
879
}
880

881
// Prepare creates a PreparedExec from a CompiledExec by preparing it on the
882
// given DB.
883
func (compiledExec *CompiledExec) Prepare(db DB) (*PreparedExec, error) {
×
884
        return compiledExec.PrepareContext(context.Background(), db)
×
885
}
×
886

887
// PrepareContext is like Prepare but additionally requires a context.Context.
888
func (compiledExec *CompiledExec) PrepareContext(ctx context.Context, db DB) (*PreparedExec, error) {
1✔
889
        var err error
1✔
890
        preparedExec := &PreparedExec{
1✔
891
                compiledExec: NewCompiledExec(compiledExec.GetSQL()),
1✔
892
        }
1✔
893
        preparedExec.stmt, err = db.PrepareContext(ctx, compiledExec.query)
1✔
894
        if err != nil {
1✔
895
                return nil, err
×
896
        }
×
897
        preparedExec.logger, _ = db.(SqLogger)
1✔
898
        if preparedExec.logger == nil {
1✔
899
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
900
                if logQuery != nil {
×
901
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
902
                        preparedExec.logger = &sqLogStruct{
×
903
                                logSettings: logSettings,
×
904
                                logQuery:    logQuery,
×
905
                        }
×
906
                }
×
907
        }
908
        return preparedExec, nil
1✔
909
}
910

911
// PrepareExec is the result of preparing a CompiledExec on a DB.
912
type PreparedExec struct {
913
        compiledExec *CompiledExec
914
        stmt         *sql.Stmt
915
        logger       SqLogger
916
}
917

918
// PrepareExec returns a new PreparedExec.
919
func PrepareExec(db DB, q Query) (*PreparedExec, error) {
1✔
920
        return PrepareExecContext(context.Background(), db, q)
1✔
921
}
1✔
922

923
// PrepareExecContext is like PrepareExec but additionally requires a
924
// context.Context.
925
func PrepareExecContext(ctx context.Context, db DB, q Query) (*PreparedExec, error) {
1✔
926
        compiledExec, err := CompileExecContext(ctx, q)
1✔
927
        if err != nil {
1✔
928
                return nil, err
×
929
        }
×
930
        return compiledExec.PrepareContext(ctx, db)
1✔
931
}
932

933
// Close closes the PreparedExec.
934
func (preparedExec *PreparedExec) Close() error {
×
935
        if preparedExec.stmt == nil {
×
936
                return nil
×
937
        }
×
938
        return preparedExec.stmt.Close()
×
939
}
940

941
// Exec executes the PreparedExec with the given params.
942
func (preparedExec *PreparedExec) Exec(params Params) (Result, error) {
5✔
943
        return preparedExec.exec(context.Background(), params, 1)
5✔
944
}
5✔
945

946
// ExecContext is like Exec but additionally requires a context.Context.
947
func (preparedExec *PreparedExec) ExecContext(ctx context.Context, params Params) (Result, error) {
×
948
        return preparedExec.exec(ctx, params, 1)
×
949
}
×
950

951
func (preparedExec *PreparedExec) exec(ctx context.Context, params Params, skip int) (result Result, err error) {
5✔
952
        queryStats := QueryStats{
5✔
953
                Dialect: preparedExec.compiledExec.dialect,
5✔
954
                Query:   preparedExec.compiledExec.query,
5✔
955
                Args:    preparedExec.compiledExec.args,
5✔
956
                Params:  preparedExec.compiledExec.params,
5✔
957
        }
5✔
958

5✔
959
        // Setup logger.
5✔
960
        var logSettings LogSettings
5✔
961
        if preparedExec.logger != nil {
10✔
962
                preparedExec.logger.SqLogSettings(ctx, &logSettings)
5✔
963
                if logSettings.IncludeCaller {
10✔
964
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
5✔
965
                }
5✔
966
                defer func() {
10✔
967
                        if logSettings.LogAsynchronously {
5✔
968
                                go preparedExec.logger.SqLogQuery(ctx, queryStats)
×
969
                        } else {
5✔
970
                                preparedExec.logger.SqLogQuery(ctx, queryStats)
5✔
971
                        }
5✔
972
                }()
973
        }
974

975
        // Substitute params.
976
        queryStats.Args, err = substituteParams(queryStats.Dialect, queryStats.Args, queryStats.Params, params)
5✔
977
        if err != nil {
5✔
978
                return result, err
×
979
        }
×
980

981
        // Run query.
982
        if logSettings.IncludeTime {
10✔
983
                queryStats.StartedAt = time.Now()
5✔
984
        }
5✔
985
        var sqlResult sql.Result
5✔
986
        sqlResult, queryStats.Err = preparedExec.stmt.ExecContext(ctx, queryStats.Args...)
5✔
987
        if logSettings.IncludeTime {
10✔
988
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
5✔
989
        }
5✔
990
        if queryStats.Err != nil {
5✔
991
                return result, queryStats.Err
×
992
        }
×
993
        return execResult(sqlResult, &queryStats)
5✔
994
}
995

996
func getFieldNames(ctx context.Context, row *Row) []string {
10✔
997
        if len(row.fields) == 0 {
13✔
998
                columns, _ := row.sqlRows.Columns()
3✔
999
                return columns
3✔
1000
        }
3✔
1001
        buf := bufpool.Get().(*bytes.Buffer)
7✔
1002
        buf.Reset()
7✔
1003
        defer bufpool.Put(buf)
7✔
1004
        var args []any
7✔
1005
        fieldNames := make([]string, 0, len(row.fields))
7✔
1006
        for _, field := range row.fields {
91✔
1007
                if alias := getAlias(field); alias != "" {
84✔
1008
                        fieldNames = append(fieldNames, alias)
×
1009
                        continue
×
1010
                }
1011
                buf.Reset()
84✔
1012
                args = args[:0]
84✔
1013
                err := field.WriteSQL(ctx, row.dialect, buf, &args, nil)
84✔
1014
                if err != nil {
84✔
1015
                        fieldNames = append(fieldNames, "%!(error="+err.Error()+")")
×
1016
                        continue
×
1017
                }
1018
                fieldName, err := Sprintf(row.dialect, buf.String(), args)
84✔
1019
                if err != nil {
84✔
1020
                        fieldNames = append(fieldNames, "%!(error="+err.Error()+")")
×
1021
                        continue
×
1022
                }
1023
                fieldNames = append(fieldNames, fieldName)
84✔
1024
        }
1025
        return fieldNames
7✔
1026
}
1027

1028
func getFieldMappings(dialect string, fields []Field, scanDest []any) string {
2✔
1029
        var buf bytes.Buffer
2✔
1030
        var args []any
2✔
1031
        var b strings.Builder
2✔
1032
        for i, field := range fields {
5✔
1033
                b.WriteString(fmt.Sprintf("\n %02d. ", i+1))
3✔
1034
                buf.Reset()
3✔
1035
                args = args[:0]
3✔
1036
                err := field.WriteSQL(context.Background(), dialect, &buf, &args, nil)
3✔
1037
                if err != nil {
3✔
1038
                        buf.WriteString("%!(error=" + err.Error() + ")")
×
1039
                        continue
×
1040
                }
1041
                fieldName, err := Sprintf(dialect, buf.String(), args)
3✔
1042
                if err != nil {
3✔
1043
                        b.WriteString("%!(error=" + err.Error() + ")")
×
1044
                        continue
×
1045
                }
1046
                b.WriteString(fieldName + " => " + reflect.TypeOf(scanDest[i]).String())
3✔
1047
        }
1048
        return b.String()
2✔
1049
}
1050

1051
func cursorResult[T any](cursor *Cursor[T]) (result T, err error) {
18✔
1052
        for cursor.Next() {
36✔
1053
                result, err = cursor.Result()
18✔
1054
                if err != nil {
18✔
1055
                        return result, err
×
1056
                }
×
1057
                break
18✔
1058
        }
1059
        if cursor.RowCount() == 0 {
18✔
1060
                return result, sql.ErrNoRows
×
1061
        }
×
1062
        return result, cursor.Close()
18✔
1063
}
1064

1065
func cursorResults[T any](cursor *Cursor[T]) (results []T, err error) {
18✔
1066
        var result T
18✔
1067
        for cursor.Next() {
84✔
1068
                result, err = cursor.Result()
66✔
1069
                if err != nil {
66✔
1070
                        return results, err
×
1071
                }
×
1072
                results = append(results, result)
66✔
1073
        }
1074
        return results, cursor.Close()
18✔
1075
}
1076

1077
func execResult(sqlResult sql.Result, queryStats *QueryStats) (Result, error) {
19✔
1078
        var err error
19✔
1079
        var result Result
19✔
1080
        if queryStats.Dialect == DialectSQLite || queryStats.Dialect == DialectMySQL {
34✔
1081
                result.LastInsertId, err = sqlResult.LastInsertId()
15✔
1082
                if err != nil {
15✔
1083
                        return result, err
×
1084
                }
×
1085
                queryStats.LastInsertId.Valid = true
15✔
1086
                queryStats.LastInsertId.Int64 = result.LastInsertId
15✔
1087
        }
1088
        result.RowsAffected, err = sqlResult.RowsAffected()
19✔
1089
        if err != nil {
19✔
1090
                return result, err
×
1091
        }
×
1092
        queryStats.RowsAffected.Valid = true
19✔
1093
        queryStats.RowsAffected.Int64 = result.RowsAffected
19✔
1094
        return result, nil
19✔
1095
}
1096

1097
// FetchExists returns a boolean indicating if running the given Query on the
1098
// given DB returned any results.
1099
func FetchExists(db DB, query Query) (exists bool, err error) {
4✔
1100
        return fetchExists(context.Background(), db, query, 1)
4✔
1101
}
4✔
1102

1103
// FetchExistsContext is like FetchExists but additionally requires a
1104
// context.Context.
1105
func FetchExistsContext(ctx context.Context, db DB, query Query) (exists bool, err error) {
×
1106
        return fetchExists(ctx, db, query, 1)
×
1107
}
×
1108

1109
func fetchExists(ctx context.Context, db DB, query Query, skip int) (exists bool, err error) {
4✔
1110
        dialect := query.GetDialect()
4✔
1111
        if dialect == "" {
4✔
1112
                defaultDialect := DefaultDialect.Load()
×
1113
                if defaultDialect != nil {
×
1114
                        dialect = *defaultDialect
×
1115
                }
×
1116
        }
1117
        queryStats := QueryStats{
4✔
1118
                Dialect: dialect,
4✔
1119
                Exists:  sql.NullBool{Valid: true},
4✔
1120
                Params:  make(map[string][]int),
4✔
1121
        }
4✔
1122

4✔
1123
        // Build query.
4✔
1124
        buf := bufpool.Get().(*bytes.Buffer)
4✔
1125
        buf.Reset()
4✔
1126
        defer bufpool.Put(buf)
4✔
1127
        if dialect == DialectSQLServer {
5✔
1128
                query = Queryf("SELECT CASE WHEN EXISTS ({}) THEN 1 ELSE 0 END", query)
1✔
1129
        } else {
4✔
1130
                query = Queryf("SELECT EXISTS ({})", query)
3✔
1131
        }
3✔
1132
        err = query.WriteSQL(ctx, dialect, buf, &queryStats.Args, queryStats.Params)
4✔
1133
        queryStats.Query = buf.String()
4✔
1134
        if err != nil {
4✔
1135
                return false, err
×
1136
        }
×
1137

1138
        // Setup logger.
1139
        var logSettings LogSettings
4✔
1140
        logger, _ := db.(SqLogger)
4✔
1141
        if logger == nil {
4✔
1142
                logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
×
1143
                if logQuery != nil {
×
1144
                        logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
×
1145
                        logger = &sqLogStruct{
×
1146
                                logSettings: logSettings,
×
1147
                                logQuery:    logQuery,
×
1148
                        }
×
1149
                }
×
1150
        }
1151
        if logger != nil {
8✔
1152
                logger.SqLogSettings(ctx, &logSettings)
4✔
1153
                if logSettings.IncludeCaller {
8✔
1154
                        queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1)
4✔
1155
                }
4✔
1156
                defer func() {
8✔
1157
                        if logSettings.LogAsynchronously {
4✔
1158
                                go logger.SqLogQuery(ctx, queryStats)
×
1159
                        } else {
4✔
1160
                                logger.SqLogQuery(ctx, queryStats)
4✔
1161
                        }
4✔
1162
                }()
1163
        }
1164

1165
        // Run query.
1166
        if logSettings.IncludeTime {
8✔
1167
                queryStats.StartedAt = time.Now()
4✔
1168
        }
4✔
1169
        var sqlRows *sql.Rows
4✔
1170
        sqlRows, queryStats.Err = db.QueryContext(ctx, queryStats.Query, queryStats.Args...)
4✔
1171
        if logSettings.IncludeTime {
8✔
1172
                queryStats.TimeTaken = time.Since(queryStats.StartedAt)
4✔
1173
        }
4✔
1174
        if queryStats.Err != nil {
4✔
1175
                return false, queryStats.Err
×
1176
        }
×
1177

1178
        for sqlRows.Next() {
8✔
1179
                err = sqlRows.Scan(&exists)
4✔
1180
                if err != nil {
4✔
1181
                        return false, err
×
1182
                }
×
1183
                break
4✔
1184
        }
1185
        queryStats.Exists.Bool = exists
4✔
1186

4✔
1187
        if err := sqlRows.Close(); err != nil {
4✔
1188
                return exists, err
×
1189
        }
×
1190
        if err := sqlRows.Err(); err != nil {
4✔
1191
                return exists, err
×
1192
        }
×
1193
        return exists, nil
4✔
1194
}
1195

1196
// substituteParams will return a new args slice by substituting values from
1197
// the given paramValues. The input args slice is untouched.
1198
func substituteParams(dialect string, args []any, paramIndexes map[string][]int, paramValues map[string]any) ([]any, error) {
21✔
1199
        if len(paramValues) == 0 {
26✔
1200
                return args, nil
5✔
1201
        }
5✔
1202
        newArgs := make([]any, len(args))
16✔
1203
        copy(newArgs, args)
16✔
1204
        var err error
16✔
1205
        for name, value := range paramValues {
65✔
1206
                indexes := paramIndexes[name]
49✔
1207
                for _, index := range indexes {
98✔
1208
                        switch arg := newArgs[index].(type) {
49✔
1209
                        case sql.NamedArg:
46✔
1210
                                arg.Value, err = preprocessValue(dialect, value)
46✔
1211
                                if err != nil {
46✔
1212
                                        return nil, err
×
1213
                                }
×
1214
                                newArgs[index] = arg
46✔
1215
                        default:
3✔
1216
                                value, err = preprocessValue(dialect, value)
3✔
1217
                                if err != nil {
3✔
1218
                                        return nil, err
×
1219
                                }
×
1220
                                newArgs[index] = value
3✔
1221
                        }
1222
                }
1223
        }
1224
        return newArgs, nil
16✔
1225
}
1226

1227
// caller reports the file, line and fully-qualified function name of the
// stack frame skip levels above the function that called caller, so caller(0)
// describes caller's own caller. An unresolvable frame yields an empty
// function name, because (*runtime.Func).Name returns "" for a nil *Func.
func caller(skip int) (file string, line int, function string) {
	var pc uintptr
	pc, file, line, _ = runtime.Caller(skip + 1)
	function = runtime.FuncForPC(pc).Name()
	return file, line, function
}
39✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc