• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stephenafamo / bob / 14640657403

24 Apr 2025 11:38AM UTC coverage: 44.274% (-3.6%) from 47.872%
14640657403

push

github

web-flow
Merge pull request #391 from stephenafamo/queries

Implement parsing of Postgres SELECT queries

66 of 1446 new or added lines in 23 files covered. (4.56%)

4 existing lines in 3 files now uncovered.

7481 of 16897 relevant lines covered (44.27%)

222.34 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.1
/gen/bobgen-sqlite/driver/sqlite.go
1
package driver
2

3
import (
4
        "context"
5
        "database/sql"
6
        "errors"
7
        "fmt"
8
        "net/url"
9
        "slices"
10
        "sort"
11
        "strconv"
12
        "strings"
13

14
        "github.com/aarondl/opt/null"
15
        helpers "github.com/stephenafamo/bob/gen/bobgen-helpers"
16
        "github.com/stephenafamo/bob/gen/bobgen-sqlite/driver/parser"
17
        "github.com/stephenafamo/bob/gen/drivers"
18
        "github.com/stephenafamo/scan"
19
        "github.com/stephenafamo/scan/stdscan"
20
        _ "github.com/tursodatabase/libsql-client-go/libsql"
21
        "github.com/volatiletech/strmangle"
22
        _ "modernc.org/sqlite"
23
)
24

25
type (
26
        Interface  = drivers.Interface[any, any, IndexExtra]
27
        DBInfo     = drivers.DBInfo[any, any, IndexExtra]
28
        IndexExtra = parser.IndexExtra
29
)
30

31
func New(config Config) Interface {
26✔
32
        if config.DriverName == "" {
52✔
33
                config.DriverName = "modernc.org/sqlite"
26✔
34
        }
26✔
35
        return &driver{config: config}
26✔
36
}
37

38
// Config holds the settings of the SQLite code-generation driver.
type Config struct {
	// The database connection string
	DSN string
	// Additional databases to attach, as a map of schema name to DSN.
	// Models are generated for the tables of every attached schema too.
	Attach map[string]string
	// Folders containing query files
	Queries []string `yaml:"queries"`
	// The name of this schema will not be included in the generated models;
	// a context value can then be used to set the schema at runtime.
	// Useful for multi-tenant setups. Defaults to "main" when empty.
	SharedSchema string `yaml:"shared_schema"`
	// Tables that will be included; all others are ignored. Keys are table
	// names; an optional list of column names narrows the entry to those columns.
	Only map[string][]string
	// Tables that should be ignored; all others are included. Keys are table
	// names; an optional list of column names narrows the entry to those columns.
	Except map[string][]string
	// Which `database/sql` driver to use (the full module name).
	// Defaults to "modernc.org/sqlite" when empty.
	DriverName string `yaml:"driver_name"`

	// Used in main.go

	// The name of the folder to output the models package to
	Output string
	// The name you wish to assign to your generated models package
	Pkgname string
	// NoFactory is read from the `no_factory` key; presumably it disables
	// generation of factories (TODO confirm against the generator).
	NoFactory bool `yaml:"no_factory"`
}
65

66
// driver holds the database connection string and a handle
67
// to the database connection.
68
type driver struct {
69
        config Config
70
        conn   *sql.DB
71
}
72

73
func (d *driver) Dialect() string {
24✔
74
        return "sqlite"
24✔
75
}
24✔
76

77
func (d *driver) Destination() string {
×
78
        return d.config.Output
×
79
}
×
80

81
func (d *driver) PackageName() string {
×
82
        return d.config.Pkgname
×
83
}
×
84

85
func (d *driver) Types() drivers.Types {
12✔
86
        return helpers.Types()
12✔
87
}
12✔
88

89
// Assemble all the information we need to provide back to the driver
90
func (d *driver) Assemble(ctx context.Context) (*DBInfo, error) {
26✔
91
        var err error
26✔
92

26✔
93
        if d.config.SharedSchema == "" {
52✔
94
                d.config.SharedSchema = "main"
26✔
95
        }
26✔
96

97
        if d.config.DSN == "" {
26✔
98
                return nil, fmt.Errorf("database dsn is not set")
×
99
        }
×
100

101
        driverName := d.inferDriver()
26✔
102
        d.conn, err = sql.Open(driverName, d.config.DSN)
26✔
103
        if err != nil {
26✔
104
                return nil, fmt.Errorf("failed to connect to database: %w", err)
×
105
        }
×
106
        defer d.conn.Close()
26✔
107

26✔
108
        for schema, dsn := range d.config.Attach {
52✔
109
                if driverName == "sqlite" {
40✔
110
                        dsn = strconv.Quote(dsn)
14✔
111
                }
14✔
112
                _, err = d.conn.ExecContext(ctx, fmt.Sprintf("attach database %s as %s", dsn, schema))
26✔
113
                if err != nil {
26✔
114
                        return nil, fmt.Errorf("could not attach %q: %w", schema, err)
×
115
                }
×
116
        }
117

118
        tables, err := d.tables(ctx)
26✔
119
        if err != nil {
26✔
120
                return nil, fmt.Errorf("getting tables: %w", err)
×
121
        }
×
122

123
        queries, err := drivers.ParseFolders(ctx, parser.New(tables), d.config.Queries...)
26✔
124
        if err != nil {
26✔
125
                return nil, fmt.Errorf("parse query folders: %w", err)
×
126
        }
×
127

128
        if driverName == "libsql" {
38✔
129
                d.config.DriverName = "github.com/tursodatabase/libsql-client-go/libsql"
12✔
130
        }
12✔
131
        dbinfo := &DBInfo{
26✔
132
                DriverName:   d.config.DriverName,
26✔
133
                Tables:       tables,
26✔
134
                QueryFolders: queries,
26✔
135
        }
26✔
136

26✔
137
        return dbinfo, nil
26✔
138
}
139

140
func (d *driver) inferDriver() string {
26✔
141
        driverName := "sqlite"
26✔
142
        if !strings.Contains(d.config.DSN, "://") {
40✔
143
                return driverName
14✔
144
        }
14✔
145
        dsn, _ := url.Parse(d.config.DSN)
12✔
146
        if dsn == nil {
12✔
147
                return driverName
×
148
        }
×
149
        libsqlSchemes := map[string]bool{
12✔
150
                "libsql": true,
12✔
151
                "file":   true,
12✔
152
                "https":  true,
12✔
153
                "http":   true,
12✔
154
                "wss":    true,
12✔
155
                "ws":     true,
12✔
156
        }
12✔
157
        if libsqlSchemes[dsn.Scheme] {
24✔
158
                driverName = "libsql"
12✔
159
        }
12✔
160
        return driverName
12✔
161
}
162

163
func (d *driver) buildQuery(schema string) (string, []any) {
52✔
164
        var args []any
52✔
165
        query := fmt.Sprintf(`SELECT name FROM %q.sqlite_schema WHERE name NOT LIKE 'sqlite_%%' AND type IN ('table', 'view')`, schema)
52✔
166

52✔
167
        tableFilter := drivers.ParseTableFilter(d.config.Only, d.config.Except)
52✔
168

52✔
169
        if len(tableFilter.Only) > 0 {
84✔
170
                var subqueries []string
32✔
171
                stringPatterns, regexPatterns := tableFilter.ClassifyPatterns(tableFilter.Only)
32✔
172
                include := make([]string, 0, len(stringPatterns))
32✔
173
                for _, name := range stringPatterns {
128✔
174
                        if (schema == "main" && !strings.Contains(name, ".")) || strings.HasPrefix(name, schema+".") {
144✔
175
                                include = append(include, strings.TrimPrefix(name, schema+"."))
48✔
176
                        }
48✔
177
                }
178
                if len(include) > 0 {
56✔
179
                        subqueries = append(subqueries, fmt.Sprintf("name in (%s)", strmangle.Placeholders(false, len(include), 1, 1)))
24✔
180
                        for _, w := range include {
72✔
181
                                args = append(args, w)
48✔
182
                        }
48✔
183
                }
184
                if len(regexPatterns) > 0 {
48✔
185
                        subqueries = append(subqueries, fmt.Sprintf("name regexp (%s)", strmangle.Placeholders(false, 1, len(args)+1, 1)))
16✔
186
                        args = append(args, strings.Join(regexPatterns, "|"))
16✔
187
                }
16✔
188
                if len(subqueries) > 0 {
64✔
189
                        query += fmt.Sprintf(" and (%s)", strings.Join(subqueries, " or "))
32✔
190
                }
32✔
191
        }
192

193
        if len(tableFilter.Except) > 0 {
84✔
194
                var subqueries []string
32✔
195
                stringPatterns, regexPatterns := tableFilter.ClassifyPatterns(tableFilter.Except)
32✔
196
                exclude := make([]string, 0, len(tableFilter.Except))
32✔
197
                for _, name := range stringPatterns {
128✔
198
                        if (schema == "main" && !strings.Contains(name, ".")) || strings.HasPrefix(name, schema+".") {
144✔
199
                                exclude = append(exclude, strings.TrimPrefix(name, schema+"."))
48✔
200
                        }
48✔
201
                }
202
                if len(exclude) > 0 {
56✔
203
                        subqueries = append(subqueries, fmt.Sprintf("name not in (%s)", strmangle.Placeholders(false, len(exclude), 1+len(args), 1)))
24✔
204
                        for _, w := range exclude {
72✔
205
                                args = append(args, w)
48✔
206
                        }
48✔
207
                }
208
                if len(regexPatterns) > 0 {
48✔
209
                        subqueries = append(subqueries, fmt.Sprintf("name not regexp (%s)", strmangle.Placeholders(false, 1, len(args)+1, 1)))
16✔
210
                        args = append(args, strings.Join(regexPatterns, "|"))
16✔
211
                }
16✔
212
                if len(subqueries) > 0 {
64✔
213
                        query += fmt.Sprintf(" and (%s)", strings.Join(subqueries, " and "))
32✔
214
                }
32✔
215
        }
216

217
        query += ` ORDER BY type, name`
52✔
218

52✔
219
        return query, args
52✔
220
}
221

222
func (d *driver) tables(ctx context.Context) (drivers.Tables[any, IndexExtra], error) {
26✔
223
        mainQuery, mainArgs := d.buildQuery("main")
26✔
224
        mainTables, err := stdscan.All(ctx, d.conn, scan.SingleColumnMapper[string], mainQuery, mainArgs...)
26✔
225
        if err != nil {
26✔
226
                return nil, err
×
227
        }
×
228

229
        colFilter := drivers.ParseColumnFilter(mainTables, d.config.Only, d.config.Except)
26✔
230
        allTables := make(drivers.Tables[any, IndexExtra], len(mainTables))
26✔
231
        for i, name := range mainTables {
206✔
232
                allTables[i], err = d.getTable(ctx, "main", name, colFilter)
180✔
233
                if err != nil {
180✔
234
                        return nil, err
×
235
                }
×
236
        }
237

238
        for schema := range d.config.Attach {
52✔
239
                schemaQuery, schemaArgs := d.buildQuery(schema)
26✔
240
                tables, err := stdscan.All(ctx, d.conn, scan.SingleColumnMapper[string], schemaQuery, schemaArgs...)
26✔
241
                if err != nil {
26✔
242
                        return nil, err
×
243
                }
×
244
                colFilter = drivers.ParseColumnFilter(tables, d.config.Only, d.config.Except)
26✔
245
                for _, name := range tables {
186✔
246
                        table, err := d.getTable(ctx, schema, name, colFilter)
160✔
247
                        if err != nil {
160✔
248
                                return nil, err
×
249
                        }
×
250
                        allTables = append(allTables, table)
160✔
251
                }
252
        }
253

254
        return allTables, nil
26✔
255
}
256

257
func (d driver) getTable(ctx context.Context, schema, name string, colFilter drivers.ColumnFilter) (drivers.Table[any, IndexExtra], error) {
340✔
258
        var err error
340✔
259

340✔
260
        table := drivers.Table[any, IndexExtra]{
340✔
261
                Key:    d.key(schema, name),
340✔
262
                Schema: d.schema(schema),
340✔
263
                Name:   name,
340✔
264
        }
340✔
265

340✔
266
        tinfo, err := d.tableInfo(ctx, schema, name)
340✔
267
        if err != nil {
340✔
268
                return table, err
×
269
        }
×
270

271
        table.Columns, err = d.columns(ctx, schema, name, tinfo, colFilter)
340✔
272
        if err != nil {
340✔
273
                return table, err
×
274
        }
×
275

276
        // We cannot rely on the indexes to get the primary key
277
        // because it is not always included in the indexes
278
        table.Constraints.Primary = d.primaryKey(schema, name, tinfo)
340✔
279
        table.Constraints.Foreign, err = d.foreignKeys(ctx, schema, name)
340✔
280
        if err != nil {
340✔
281
                return table, err
×
282
        }
×
283

284
        table.Indexes, err = d.indexes(ctx, schema, name)
340✔
285
        if err != nil {
340✔
286
                return table, err
×
287
        }
×
288

289
        // Get Unique constraints from indexes
290
        // Also check if the primary key is in the indexes
291
        hasPk := false
340✔
292
        for _, index := range table.Indexes {
550✔
293
                constraint := drivers.Constraint[any]{
210✔
294
                        Name:    index.Name,
210✔
295
                        Columns: index.NonExpressionColumns(),
210✔
296
                }
210✔
297

210✔
298
                switch index.Type {
210✔
299
                case "pk":
110✔
300
                        hasPk = true
110✔
301
                case "u":
40✔
302
                        table.Constraints.Uniques = append(table.Constraints.Uniques, constraint)
40✔
303
                }
304
        }
305

306
        // Add the primary key to the indexes if it is not already there
307
        if !hasPk && table.Constraints.Primary != nil {
540✔
308
                pkIndex := drivers.Index[IndexExtra]{
200✔
309
                        Type:    "pk",
200✔
310
                        Name:    table.Constraints.Primary.Name,
200✔
311
                        Columns: make([]drivers.IndexColumn, len(table.Constraints.Primary.Columns)),
200✔
312
                        Unique:  true,
200✔
313
                }
200✔
314

200✔
315
                for i, col := range table.Constraints.Primary.Columns {
400✔
316
                        pkIndex.Columns[i] = drivers.IndexColumn{
200✔
317
                                Name: col,
200✔
318
                        }
200✔
319
                }
200✔
320

321
                // List the primary key first
322
                table.Indexes = append([]drivers.Index[IndexExtra]{pkIndex}, table.Indexes...)
200✔
323
        }
324

325
        return table, nil
340✔
326
}
327

328
// Columns takes a table name and attempts to retrieve the table information
329
// from the database. It retrieves the column names
330
// and column types and returns those as a []Column after TranslateColumnType()
331
// converts the SQL types to Go types, for example: "varchar" to "string"
332
func (d driver) columns(ctx context.Context, schema, tableName string, tinfo []info, colFilter drivers.ColumnFilter) ([]drivers.Column, error) {
340✔
333
        var columns []drivers.Column //nolint:prealloc
340✔
334

340✔
335
        //nolint:gosec
340✔
336
        query := fmt.Sprintf("SELECT 1 FROM '%s'.sqlite_master WHERE type = 'table' AND name = ? AND sql LIKE '%%AUTOINCREMENT%%'", schema)
340✔
337
        result, err := d.conn.QueryContext(ctx, query, tableName)
340✔
338
        if err != nil {
340✔
339
                return nil, fmt.Errorf("autoincr query: %w", err)
×
340
        }
×
341
        tableHasAutoIncr := result.Next()
340✔
342
        if err := result.Close(); err != nil {
340✔
343
                return nil, err
×
344
        }
×
345

346
        nPkeys := 0
340✔
347
        for _, column := range tinfo {
2,190✔
348
                if column.Pk != 0 {
2,180✔
349
                        nPkeys++
330✔
350
                }
330✔
351
        }
352

353
        filter := colFilter[tableName]
340✔
354
        excludedColumns := make(map[string]struct{}, len(filter.Except))
340✔
355
        if len(filter.Except) > 0 {
444✔
356
                for _, w := range filter.Except {
208✔
357
                        excludedColumns[w] = struct{}{}
104✔
358
                }
104✔
359
        }
360

361
        for _, colInfo := range tinfo {
2,190✔
362
                if _, ok := excludedColumns[colInfo.Name]; ok {
1,874✔
363
                        continue
24✔
364
                }
365
                column := drivers.Column{
1,826✔
366
                        Name:     colInfo.Name,
1,826✔
367
                        DBType:   strings.ToUpper(colInfo.Type),
1,826✔
368
                        Nullable: !colInfo.NotNull && colInfo.Pk < 1,
1,826✔
369
                }
1,826✔
370

1,826✔
371
                isPrimaryKeyInteger := colInfo.Pk == 1 && column.DBType == "INTEGER"
1,826✔
372
                // This is special behavior noted in the sqlite documentation.
1,826✔
373
                // An integer primary key becomes synonymous with the internal ROWID
1,826✔
374
                // and acts as an auto incrementing value. Although there's important
1,826✔
375
                // differences between using the keyword AUTOINCREMENT and this inferred
1,826✔
376
                // version, they don't matter here so just masquerade as the same thing as
1,826✔
377
                // above.
1,826✔
378
                autoIncr := isPrimaryKeyInteger && (tableHasAutoIncr || nPkeys == 1)
1,826✔
379

1,826✔
380
                // See: https://github.com/sqlite/sqlite/blob/91f621531dc1cb9ba5f6a47eb51b1de9ed8bdd07/src/pragma.c#L1165
1,826✔
381
                column.Generated = colInfo.Hidden == 2 || colInfo.Hidden == 3
1,826✔
382

1,826✔
383
                if colInfo.DefaultValue.Valid {
2,166✔
384
                        column.Default = colInfo.DefaultValue.String
340✔
385
                } else if autoIncr {
2,026✔
386
                        column.Default = "auto_increment"
200✔
387
                } else if column.Generated {
1,526✔
388
                        column.Default = "auto_generated"
40✔
389
                }
40✔
390

391
                if column.Nullable && column.Default == "" {
2,476✔
392
                        column.Default = "NULL"
650✔
393
                }
650✔
394

395
                column.Type = parser.TranslateColumnType(column.DBType)
1,826✔
396
                columns = append(columns, column)
1,826✔
397
        }
398

399
        return columns, nil
340✔
400
}
401

402
func (s driver) tableInfo(ctx context.Context, schema, tableName string) ([]info, error) {
340✔
403
        var ret []info
340✔
404
        rows, err := s.conn.QueryContext(ctx, fmt.Sprintf("PRAGMA '%s'.table_xinfo('%s')", schema, tableName))
340✔
405
        if err != nil {
340✔
406
                return nil, err
×
407
        }
×
408
        defer rows.Close()
340✔
409

340✔
410
        for rows.Next() {
2,190✔
411
                tinfo := info{}
1,850✔
412
                if err := rows.Scan(&tinfo.Cid, &tinfo.Name, &tinfo.Type, &tinfo.NotNull, &tinfo.DefaultValue, &tinfo.Pk, &tinfo.Hidden); err != nil {
1,850✔
413
                        return nil, fmt.Errorf("unable to scan for table %s: %w", tableName, err)
×
414
                }
×
415

416
                ret = append(ret, tinfo)
1,850✔
417
        }
418
        return ret, nil
340✔
419
}
420

421
// primaryKey looks up the primary key for a table.
422
func (s driver) primaryKey(schema, tableName string, tinfo []info) *drivers.Constraint[any] {
340✔
423
        var cols []string
340✔
424
        for _, c := range tinfo {
2,190✔
425
                if c.Pk > 0 {
2,180✔
426
                        cols = append(cols, c.Name)
330✔
427
                }
330✔
428
        }
429

430
        if len(cols) == 0 {
370✔
431
                return nil
30✔
432
        }
30✔
433

434
        return &drivers.Constraint[any]{
310✔
435
                Name:    fmt.Sprintf("pk_%s_%s", schema, tableName),
310✔
436
                Columns: cols,
310✔
437
        }
310✔
438
}
439

440
func (d driver) skipKey(table, column string) bool {
110✔
441
        if len(d.config.Only) > 0 {
110✔
442
                // check if the table is listed at all
×
443
                filter, ok := d.config.Only[table]
×
444
                if !ok {
×
445
                        return true
×
446
                }
×
447

448
                // check if the column is listed
449
                if len(filter) == 0 {
×
450
                        return false
×
451
                }
×
452

NEW
453
                return !slices.Contains(filter, column)
×
454
        }
455

456
        if len(d.config.Except) > 0 {
154✔
457
                filter, ok := d.config.Except[table]
44✔
458
                if !ok {
88✔
459
                        return false
44✔
460
                }
44✔
461

462
                if len(filter) == 0 {
×
463
                        return true
×
464
                }
×
465

NEW
466
                if slices.Contains(filter, column) {
×
NEW
467
                        return true
×
UNCOV
468
                }
×
469
        }
470

471
        return false
66✔
472
}
473

474
// foreignKeys retrieves the foreign keys for a given table name.
475
func (d driver) foreignKeys(ctx context.Context, schema, tableName string) ([]drivers.ForeignKey[any], error) {
340✔
476
        rows, err := d.conn.QueryContext(ctx, fmt.Sprintf("PRAGMA '%s'.foreign_key_list('%s')", schema, tableName))
340✔
477
        if err != nil {
340✔
478
                return nil, err
×
479
        }
×
480
        defer rows.Close()
340✔
481

340✔
482
        fkeyMap := make(map[int]drivers.ForeignKey[any])
340✔
483
        for rows.Next() {
450✔
484
                var id, seq int
110✔
485
                var ftable, col string
110✔
486
                var fcolNullable null.Val[string]
110✔
487

110✔
488
                // not used
110✔
489
                var onupdate, ondelete, match string
110✔
490

110✔
491
                err = rows.Scan(&id, &seq, &ftable, &col, &fcolNullable, &onupdate, &ondelete, &match)
110✔
492
                if err != nil {
110✔
493
                        return nil, err
×
494
                }
×
495

496
                fullFtable := ftable
110✔
497
                if schema != "main" {
150✔
498
                        fullFtable = fmt.Sprintf("%s.%s", schema, ftable)
40✔
499
                }
40✔
500

501
                fcol, _ := fcolNullable.Get()
110✔
502
                if fcol == "" {
120✔
503
                        fcol, err = stdscan.One(
10✔
504
                                ctx, d.conn, scan.SingleColumnMapper[string],
10✔
505
                                fmt.Sprintf("SELECT name FROM pragma_table_info('%s', '%s') WHERE pk = ?", ftable, schema), seq+1,
10✔
506
                        )
10✔
507
                        if err != nil {
10✔
508
                                return nil, fmt.Errorf("could not find column %q in table %q: %w", col, ftable, err)
×
509
                        }
×
510
                }
511

512
                if d.skipKey(fullFtable, fcol) {
110✔
513
                        continue
×
514
                }
515

516
                fkeyMap[id] = drivers.ForeignKey[any]{
110✔
517
                        Constraint: drivers.Constraint[any]{
110✔
518
                                Name:    fmt.Sprintf("fk_%s_%d", tableName, id),
110✔
519
                                Columns: append(fkeyMap[id].Columns, col),
110✔
520
                        },
110✔
521
                        ForeignTable:   d.key(schema, ftable),
110✔
522
                        ForeignColumns: append(fkeyMap[id].ForeignColumns, fcol),
110✔
523
                }
110✔
524
        }
525

526
        if err = rows.Err(); err != nil {
340✔
527
                return nil, err
×
528
        }
×
529

530
        fkeys := make([]drivers.ForeignKey[any], 0, len(fkeyMap))
340✔
531

340✔
532
        for _, fkey := range fkeyMap {
440✔
533
                fkeys = append(fkeys, fkey)
100✔
534
        }
100✔
535

536
        sort.Slice(fkeys, func(i, j int) bool {
380✔
537
                return fkeys[i].Name < fkeys[j].Name
40✔
538
        })
40✔
539

540
        return fkeys, nil
340✔
541
}
542

543
// info is one row of PRAGMA table_xinfo output. tableInfo scans the fields
// positionally in the order declared here.
type info struct {
	Cid          int            // column index within the table
	Name         string         // column name
	Type         string         // declared column type as written in the DDL
	NotNull      bool           // true when the column is declared NOT NULL
	DefaultValue sql.NullString // default value expression; invalid when there is none
	Pk           int            // 1-based position in the primary key, or 0 if not part of it
	Hidden       int            // hidden-column marker; 2 and 3 denote generated columns
}
554

555
func (d *driver) key(schema string, table string) string {
450✔
556
        key := table
450✔
557
        if schema != "" && schema != d.config.SharedSchema {
650✔
558
                key = schema + "." + table
200✔
559
        }
200✔
560

561
        return key
450✔
562
}
563

564
func (d *driver) schema(schema string) string {
340✔
565
        if schema == d.config.SharedSchema {
520✔
566
                return ""
180✔
567
        }
180✔
568

569
        return schema
160✔
570
}
571

572
func (d *driver) indexes(ctx context.Context, schema, tableName string) ([]drivers.Index[IndexExtra], error) {
340✔
573
        query := fmt.Sprintf(`
340✔
574
        SELECT name, "unique", origin, partial
340✔
575
        FROM pragma_index_list('%s', '%s') ORDER BY seq ASC
340✔
576
        `, tableName, schema)
340✔
577
        indexNames, err := stdscan.All(ctx, d.conn, scan.StructMapper[struct {
340✔
578
                Name    string
340✔
579
                Unique  bool
340✔
580
                Origin  string
340✔
581
                Partial bool
340✔
582
        }](), query)
340✔
583
        if err != nil {
340✔
584
                return nil, err
×
585
        }
×
586

587
        indexes := make([]drivers.Index[IndexExtra], len(indexNames))
340✔
588
        for i, index := range indexNames {
550✔
589
                cols, err := d.getIndexInformation(ctx, schema, tableName, index.Name)
210✔
590
                if err != nil {
210✔
591
                        return nil, err
×
592
                }
×
593
                indexes[i] = drivers.Index[IndexExtra]{
210✔
594
                        Type:    index.Origin,
210✔
595
                        Name:    index.Name,
210✔
596
                        Unique:  index.Unique,
210✔
597
                        Columns: cols,
210✔
598
                        Extra: IndexExtra{
210✔
599
                                Partial: index.Partial,
210✔
600
                        },
210✔
601
                }
210✔
602

603
        }
604

605
        return indexes, nil
340✔
606
}
607

608
func (d *driver) getIndexInformation(ctx context.Context, schema, tableName, indexName string) ([]drivers.IndexColumn, error) {
210✔
609
        colExpressions, err := d.extractIndexExpressions(ctx, schema, tableName, indexName)
210✔
610
        if err != nil {
210✔
611
                return nil, err
×
612
        }
×
613

614
        query := fmt.Sprintf(`
210✔
615
            SELECT seqno, name, desc
210✔
616
            FROM pragma_index_xinfo('%s', '%s')
210✔
617
            WHERE key = 1
210✔
618
            ORDER BY seqno ASC`,
210✔
619
                indexName, schema)
210✔
620

210✔
621
        var columns []drivers.IndexColumn //nolint:prealloc
210✔
622
        for column, err := range stdscan.Each(ctx, d.conn, scan.StructMapper[struct {
210✔
623
                Seqno int
210✔
624
                Name  sql.NullString
210✔
625
                Desc  bool
210✔
626
        }](), query) {
480✔
627
                if err != nil {
270✔
628
                        return nil, err
×
629
                }
×
630

631
                col := drivers.IndexColumn{
270✔
632
                        Name: column.Name.String,
270✔
633
                        Desc: column.Desc,
270✔
634
                }
270✔
635

270✔
636
                if !column.Name.Valid {
310✔
637
                        col.Name = colExpressions[column.Seqno]
40✔
638
                        col.IsExpression = true
40✔
639
                }
40✔
640

641
                columns = append(columns, col)
270✔
642
        }
643

644
        return columns, nil
210✔
645
}
646

647
func (d driver) extractIndexExpressions(ctx context.Context, schema, tableName, indexName string) ([]string, error) {
210✔
648
        var nullDDL sql.NullString
210✔
649

210✔
650
        //nolint:gosec
210✔
651
        query := fmt.Sprintf("SELECT sql FROM '%s'.sqlite_master WHERE type = 'index' AND name = ? AND tbl_name = ?", schema)
210✔
652
        result := d.conn.QueryRowContext(ctx, query, indexName, tableName)
210✔
653
        err := result.Scan(&nullDDL)
210✔
654
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
210✔
655
                return nil, fmt.Errorf("failed retrieving index DDL statement: %w", err)
×
656
        }
×
657

658
        if !nullDDL.Valid {
360✔
659
                return nil, nil
150✔
660
        }
150✔
661

662
        ddl := nullDDL.String
60✔
663
        // We're following the parsing logic from the `intckParseCreateIndex` function in the SQLite source code.
60✔
664
        // 1. https://github.com/sqlite/sqlite/blob/1d8cde9d56d153767e98595c4b015221864ef0e7/ext/intck/sqlite3intck.c#L363
60✔
665
        // 2. https://www.sqlite.org/lang_createindex.html
60✔
666

60✔
667
        // skip forward until the first "(" token
60✔
668
        i := strings.Index(ddl, "(")
60✔
669
        if i == -1 {
60✔
670
                return nil, fmt.Errorf("failed locating first column: %w", err)
×
671
        }
×
672
        ddl = ddl[i+1:]
60✔
673
        // discard the WHERE clause fragment (if one exists)
60✔
674
        i = strings.LastIndex(ddl, ")")
60✔
675
        if i == -1 {
60✔
676
                return nil, fmt.Errorf("failed locating last column: %w", err)
×
677
        }
×
678
        ddl = ddl[:i]
60✔
679
        // organize column definitions into a list
60✔
680
        colDefs := d.splitColumnDefinitions(ddl)
60✔
681

60✔
682
        expressions := make([]string, len(colDefs))
60✔
683
        for seqNo, expression := range colDefs {
150✔
684
                expressions[seqNo] = strings.TrimSpace(expression)
90✔
685
        }
90✔
686

687
        return expressions, nil
60✔
688
}
689

690
// splitColumnDefinitions performs an intelligent split of the DDL part defining the index columns.
691
//
692
// We cannot perform a simple `strings.Split(ddl, ",")` as `ddl` could contain functional expressions, i.e.:
693
//
694
//        sql  := CREATE INDEX idx ON test (col1, (col2 + col3), (POW(col3, 2)));
695
//        ddl  := "col1, (col2 + col3), (POW(col3, 2))"
696
//        defs := []string{"col1", "(col2 + col3)", "(POW(col3, 2))"}
697
func (d driver) splitColumnDefinitions(ddl string) []string {
60✔
698
        var defs []string
60✔
699
        var i, pOpen int
60✔
700

60✔
701
        for j := range len(ddl) {
930✔
702
                if ddl[j] == '(' {
910✔
703
                        pOpen++
40✔
704
                }
40✔
705
                if ddl[j] == ')' {
910✔
706
                        pOpen--
40✔
707
                }
40✔
708
                if pOpen == 0 && ddl[j] == ',' {
900✔
709
                        defs = append(defs, ddl[i:j])
30✔
710
                        i = j + 1
30✔
711
                }
30✔
712
        }
713

714
        if i < len(ddl) {
120✔
715
                defs = append(defs, ddl[i:])
60✔
716
        }
60✔
717

718
        return defs
60✔
719
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc