• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stephenafamo / bob / 14351235503

09 Apr 2025 07:14AM UTC coverage: 47.87% (-1.5%) from 49.32%
14351235503

Pull #388

github

stephenafamo
Implement parsing of SQLite SELECT queries
Pull Request #388: Implement parsing of SQLite SELECT queries

1093 of 2670 new or added lines in 29 files covered. (40.94%)

4 existing lines in 4 files now uncovered.

7471 of 15607 relevant lines covered (47.87%)

240.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.27
/gen/bobgen-sqlite/driver/sqlite.go
1
package driver
2

3
import (
4
        "context"
5
        "database/sql"
6
        "errors"
7
        "fmt"
8
        "net/url"
9
        "sort"
10
        "strconv"
11
        "strings"
12

13
        "github.com/aarondl/opt/null"
14
        helpers "github.com/stephenafamo/bob/gen/bobgen-helpers"
15
        "github.com/stephenafamo/bob/gen/bobgen-sqlite/driver/parser"
16
        "github.com/stephenafamo/bob/gen/drivers"
17
        "github.com/stephenafamo/scan"
18
        "github.com/stephenafamo/scan/stdscan"
19
        _ "github.com/tursodatabase/libsql-client-go/libsql"
20
        "github.com/volatiletech/strmangle"
21
        _ "modernc.org/sqlite"
22
)
23

24
type (
25
        Interface  = drivers.Interface[any, any, IndexExtra]
26
        DBInfo     = drivers.DBInfo[any, any, IndexExtra]
27
        IndexExtra = parser.IndexExtra
28
)
29

30
func New(config Config) Interface {
26✔
31
        if config.DriverName == "" {
52✔
32
                config.DriverName = "modernc.org/sqlite"
26✔
33
        }
26✔
34
        return &driver{config: config}
26✔
35
}
36

37
// Config holds the settings of the SQLite driver.
type Config struct {
	// The database connection string
	DSN string
	// The database schemas to generate models for:
	// a map of the schema name to the DSN
	Attach map[string]string
	// Folders containing query files
	Queries []string `yaml:"queries"`
	// The name of this schema will not be included in the generated models;
	// a context value can then be used to set the schema at runtime.
	// This is useful for multi-tenant setups. Assemble defaults it to "main".
	SharedSchema string `yaml:"shared_schema"`
	// List of tables that will be included. Others are ignored.
	// Keys are table names ("schema.table" for attached schemas) mapped to
	// the columns to keep.
	Only map[string][]string
	// List of tables that should be ignored. Others are included.
	// Keys are table names ("schema.table" for attached schemas) mapped to
	// the columns to drop.
	Except map[string][]string
	// Which `database/sql` driver to use (the full module name).
	// New defaults it to "modernc.org/sqlite".
	DriverName string `yaml:"driver_name"`

	// Used in main.go

	// The name of the folder to output the models package to
	Output string
	// The name you wish to assign to your generated models package
	Pkgname string
	// NOTE(review): read outside this file; presumably disables generation of
	// the factory package — confirm in main.go.
	NoFactory bool `yaml:"no_factory"`
}
64

65
// driver implements Interface for SQLite. It keeps the Config it was
66
// created with and the *sql.DB handle that Assemble opens and closes.
67
type driver struct {
68
        config Config  // Assemble may fill in SharedSchema and DriverName
69
        conn   *sql.DB // only usable while Assemble is running
70
}
71

72
func (d *driver) Dialect() string {
24✔
73
        return "sqlite"
24✔
74
}
24✔
75

76
func (d *driver) Destination() string {
×
77
        return d.config.Output
×
78
}
×
79

80
func (d *driver) PackageName() string {
×
81
        return d.config.Pkgname
×
82
}
×
83

84
func (d *driver) Types() drivers.Types {
12✔
85
        return helpers.Types()
12✔
86
}
12✔
87

88
// Assemble all the information we need to provide back to the driver
89
func (d *driver) Assemble(ctx context.Context) (*DBInfo, error) {
26✔
90
        var err error
26✔
91

26✔
92
        if d.config.SharedSchema == "" {
52✔
93
                d.config.SharedSchema = "main"
26✔
94
        }
26✔
95

96
        if d.config.DSN == "" {
26✔
97
                return nil, fmt.Errorf("database dsn is not set")
×
98
        }
×
99

100
        driverName := d.inferDriver()
26✔
101
        d.conn, err = sql.Open(driverName, d.config.DSN)
26✔
102
        if err != nil {
26✔
103
                return nil, fmt.Errorf("failed to connect to database: %w", err)
×
104
        }
×
105
        defer d.conn.Close()
26✔
106

26✔
107
        for schema, dsn := range d.config.Attach {
52✔
108
                if driverName == "sqlite" {
40✔
109
                        dsn = strconv.Quote(dsn)
14✔
110
                }
14✔
111
                _, err = d.conn.ExecContext(ctx, fmt.Sprintf("attach database %s as %s", dsn, schema))
26✔
112
                if err != nil {
26✔
113
                        return nil, fmt.Errorf("could not attach %q: %w", schema, err)
×
114
                }
×
115
        }
116

117
        tables, err := d.tables(ctx)
26✔
118
        if err != nil {
26✔
NEW
119
                return nil, fmt.Errorf("getting tables: %w", err)
×
NEW
120
        }
×
121

122
        queries, err := parser.New(tables).ParseFolders(d.config.Queries...)
26✔
123
        if err != nil {
26✔
NEW
124
                return nil, fmt.Errorf("parse query folders: %w", err)
×
UNCOV
125
        }
×
126

127
        if driverName == "libsql" {
38✔
128
                d.config.DriverName = "github.com/tursodatabase/libsql-client-go/libsql"
12✔
129
        }
12✔
130
        dbinfo := &DBInfo{
26✔
131
                DriverName:   d.config.DriverName,
26✔
132
                Tables:       tables,
26✔
133
                QueryFolders: queries,
26✔
134
        }
26✔
135

26✔
136
        return dbinfo, nil
26✔
137
}
138

139
func (d *driver) inferDriver() string {
26✔
140
        driverName := "sqlite"
26✔
141
        if !strings.Contains(d.config.DSN, "://") {
40✔
142
                return driverName
14✔
143
        }
14✔
144
        dsn, _ := url.Parse(d.config.DSN)
12✔
145
        if dsn == nil {
12✔
146
                return driverName
×
147
        }
×
148
        libsqlSchemes := map[string]bool{
12✔
149
                "libsql": true,
12✔
150
                "file":   true,
12✔
151
                "https":  true,
12✔
152
                "http":   true,
12✔
153
                "wss":    true,
12✔
154
                "ws":     true,
12✔
155
        }
12✔
156
        if libsqlSchemes[dsn.Scheme] {
24✔
157
                driverName = "libsql"
12✔
158
        }
12✔
159
        return driverName
12✔
160
}
161

162
func (d *driver) buildQuery(schema string) (string, []any) {
52✔
163
        var args []any
52✔
164
        query := fmt.Sprintf(`SELECT name FROM %q.sqlite_schema WHERE name NOT LIKE 'sqlite_%%' AND type IN ('table', 'view')`, schema)
52✔
165

52✔
166
        tableFilter := drivers.ParseTableFilter(d.config.Only, d.config.Except)
52✔
167

52✔
168
        if len(tableFilter.Only) > 0 {
84✔
169
                var subqueries []string
32✔
170
                stringPatterns, regexPatterns := tableFilter.ClassifyPatterns(tableFilter.Only)
32✔
171
                include := make([]string, 0, len(stringPatterns))
32✔
172
                for _, name := range stringPatterns {
128✔
173
                        if (schema == "main" && !strings.Contains(name, ".")) || strings.HasPrefix(name, schema+".") {
144✔
174
                                include = append(include, strings.TrimPrefix(name, schema+"."))
48✔
175
                        }
48✔
176
                }
177
                if len(include) > 0 {
56✔
178
                        subqueries = append(subqueries, fmt.Sprintf("name in (%s)", strmangle.Placeholders(false, len(include), 1, 1)))
24✔
179
                        for _, w := range include {
72✔
180
                                args = append(args, w)
48✔
181
                        }
48✔
182
                }
183
                if len(regexPatterns) > 0 {
48✔
184
                        subqueries = append(subqueries, fmt.Sprintf("name regexp (%s)", strmangle.Placeholders(false, 1, len(args)+1, 1)))
16✔
185
                        args = append(args, strings.Join(regexPatterns, "|"))
16✔
186
                }
16✔
187
                if len(subqueries) > 0 {
64✔
188
                        query += fmt.Sprintf(" and (%s)", strings.Join(subqueries, " or "))
32✔
189
                }
32✔
190
        }
191

192
        if len(tableFilter.Except) > 0 {
84✔
193
                var subqueries []string
32✔
194
                stringPatterns, regexPatterns := tableFilter.ClassifyPatterns(tableFilter.Except)
32✔
195
                exclude := make([]string, 0, len(tableFilter.Except))
32✔
196
                for _, name := range stringPatterns {
128✔
197
                        if (schema == "main" && !strings.Contains(name, ".")) || strings.HasPrefix(name, schema+".") {
144✔
198
                                exclude = append(exclude, strings.TrimPrefix(name, schema+"."))
48✔
199
                        }
48✔
200
                }
201
                if len(exclude) > 0 {
56✔
202
                        subqueries = append(subqueries, fmt.Sprintf("name not in (%s)", strmangle.Placeholders(false, len(exclude), 1+len(args), 1)))
24✔
203
                        for _, w := range exclude {
72✔
204
                                args = append(args, w)
48✔
205
                        }
48✔
206
                }
207
                if len(regexPatterns) > 0 {
48✔
208
                        subqueries = append(subqueries, fmt.Sprintf("name not regexp (%s)", strmangle.Placeholders(false, 1, len(args)+1, 1)))
16✔
209
                        args = append(args, strings.Join(regexPatterns, "|"))
16✔
210
                }
16✔
211
                if len(subqueries) > 0 {
64✔
212
                        query += fmt.Sprintf(" and (%s)", strings.Join(subqueries, " and "))
32✔
213
                }
32✔
214
        }
215

216
        query += ` ORDER BY type, name`
52✔
217

52✔
218
        return query, args
52✔
219
}
220

221
func (d *driver) tables(ctx context.Context) (drivers.Tables[any, IndexExtra], error) {
26✔
222
        mainQuery, mainArgs := d.buildQuery("main")
26✔
223
        mainTables, err := stdscan.All(ctx, d.conn, scan.SingleColumnMapper[string], mainQuery, mainArgs...)
26✔
224
        if err != nil {
26✔
225
                return nil, err
×
226
        }
×
227

228
        colFilter := drivers.ParseColumnFilter(mainTables, d.config.Only, d.config.Except)
26✔
229
        allTables := make(drivers.Tables[any, IndexExtra], len(mainTables))
26✔
230
        for i, name := range mainTables {
206✔
231
                allTables[i], err = d.getTable(ctx, "main", name, colFilter)
180✔
232
                if err != nil {
180✔
233
                        return nil, err
×
234
                }
×
235
        }
236

237
        for schema := range d.config.Attach {
52✔
238
                schemaQuery, schemaArgs := d.buildQuery(schema)
26✔
239
                tables, err := stdscan.All(ctx, d.conn, scan.SingleColumnMapper[string], schemaQuery, schemaArgs...)
26✔
240
                if err != nil {
26✔
241
                        return nil, err
×
242
                }
×
243
                colFilter = drivers.ParseColumnFilter(tables, d.config.Only, d.config.Except)
26✔
244
                for _, name := range tables {
186✔
245
                        table, err := d.getTable(ctx, schema, name, colFilter)
160✔
246
                        if err != nil {
160✔
247
                                return nil, err
×
248
                        }
×
249
                        allTables = append(allTables, table)
160✔
250
                }
251
        }
252

253
        return allTables, nil
26✔
254
}
255

256
func (d driver) getTable(ctx context.Context, schema, name string, colFilter drivers.ColumnFilter) (drivers.Table[any, IndexExtra], error) {
340✔
257
        var err error
340✔
258

340✔
259
        table := drivers.Table[any, IndexExtra]{
340✔
260
                Key:    d.key(schema, name),
340✔
261
                Schema: d.schema(schema),
340✔
262
                Name:   name,
340✔
263
        }
340✔
264

340✔
265
        tinfo, err := d.tableInfo(ctx, schema, name)
340✔
266
        if err != nil {
340✔
267
                return table, err
×
268
        }
×
269

270
        table.Columns, err = d.columns(ctx, schema, name, tinfo, colFilter)
340✔
271
        if err != nil {
340✔
272
                return table, err
×
273
        }
×
274

275
        // We cannot rely on the indexes to get the primary key
276
        // because it is not always included in the indexes
277
        table.Constraints.Primary = d.primaryKey(schema, name, tinfo)
340✔
278
        table.Constraints.Foreign, err = d.foreignKeys(ctx, schema, name)
340✔
279
        if err != nil {
340✔
280
                return table, err
×
281
        }
×
282

283
        table.Indexes, err = d.indexes(ctx, schema, name)
340✔
284
        if err != nil {
340✔
285
                return table, err
×
286
        }
×
287

288
        // Get Unique constraints from indexes
289
        // Also check if the primary key is in the indexes
290
        hasPk := false
340✔
291
        for _, index := range table.Indexes {
550✔
292
                constraint := drivers.Constraint[any]{
210✔
293
                        Name:    index.Name,
210✔
294
                        Columns: index.NonExpressionColumns(),
210✔
295
                }
210✔
296

210✔
297
                switch index.Type {
210✔
298
                case "pk":
110✔
299
                        hasPk = true
110✔
300
                case "u":
40✔
301
                        table.Constraints.Uniques = append(table.Constraints.Uniques, constraint)
40✔
302
                }
303
        }
304

305
        // Add the primary key to the indexes if it is not already there
306
        if !hasPk && table.Constraints.Primary != nil {
540✔
307
                pkIndex := drivers.Index[IndexExtra]{
200✔
308
                        Type:    "pk",
200✔
309
                        Name:    table.Constraints.Primary.Name,
200✔
310
                        Columns: make([]drivers.IndexColumn, len(table.Constraints.Primary.Columns)),
200✔
311
                        Unique:  true,
200✔
312
                }
200✔
313

200✔
314
                for i, col := range table.Constraints.Primary.Columns {
400✔
315
                        pkIndex.Columns[i] = drivers.IndexColumn{
200✔
316
                                Name: col,
200✔
317
                        }
200✔
318
                }
200✔
319

320
                // List the primary key first
321
                table.Indexes = append([]drivers.Index[IndexExtra]{pkIndex}, table.Indexes...)
200✔
322
        }
323

324
        return table, nil
340✔
325
}
326

327
// Columns takes a table name and attempts to retrieve the table information
328
// from the database. It retrieves the column names
329
// and column types and returns those as a []Column after TranslateColumnType()
330
// converts the SQL types to Go types, for example: "varchar" to "string"
331
func (d driver) columns(ctx context.Context, schema, tableName string, tinfo []info, colFilter drivers.ColumnFilter) ([]drivers.Column, error) {
340✔
332
        var columns []drivers.Column //nolint:prealloc
340✔
333

340✔
334
        //nolint:gosec
340✔
335
        query := fmt.Sprintf("SELECT 1 FROM '%s'.sqlite_master WHERE type = 'table' AND name = ? AND sql LIKE '%%AUTOINCREMENT%%'", schema)
340✔
336
        result, err := d.conn.QueryContext(ctx, query, tableName)
340✔
337
        if err != nil {
340✔
338
                return nil, fmt.Errorf("autoincr query: %w", err)
×
339
        }
×
340
        tableHasAutoIncr := result.Next()
340✔
341
        if err := result.Close(); err != nil {
340✔
342
                return nil, err
×
343
        }
×
344

345
        nPkeys := 0
340✔
346
        for _, column := range tinfo {
2,190✔
347
                if column.Pk != 0 {
2,180✔
348
                        nPkeys++
330✔
349
                }
330✔
350
        }
351

352
        filter := colFilter[tableName]
340✔
353
        excludedColumns := make(map[string]struct{}, len(filter.Except))
340✔
354
        if len(filter.Except) > 0 {
444✔
355
                for _, w := range filter.Except {
208✔
356
                        excludedColumns[w] = struct{}{}
104✔
357
                }
104✔
358
        }
359

360
        for _, colInfo := range tinfo {
2,190✔
361
                if _, ok := excludedColumns[colInfo.Name]; ok {
1,874✔
362
                        continue
24✔
363
                }
364
                column := drivers.Column{
1,826✔
365
                        Name:     colInfo.Name,
1,826✔
366
                        DBType:   strings.ToUpper(colInfo.Type),
1,826✔
367
                        Nullable: !colInfo.NotNull && colInfo.Pk < 1,
1,826✔
368
                }
1,826✔
369

1,826✔
370
                isPrimaryKeyInteger := colInfo.Pk == 1 && column.DBType == "INTEGER"
1,826✔
371
                // This is special behavior noted in the sqlite documentation.
1,826✔
372
                // An integer primary key becomes synonymous with the internal ROWID
1,826✔
373
                // and acts as an auto incrementing value. Although there's important
1,826✔
374
                // differences between using the keyword AUTOINCREMENT and this inferred
1,826✔
375
                // version, they don't matter here so just masquerade as the same thing as
1,826✔
376
                // above.
1,826✔
377
                autoIncr := isPrimaryKeyInteger && (tableHasAutoIncr || nPkeys == 1)
1,826✔
378

1,826✔
379
                // See: https://github.com/sqlite/sqlite/blob/91f621531dc1cb9ba5f6a47eb51b1de9ed8bdd07/src/pragma.c#L1165
1,826✔
380
                column.Generated = colInfo.Hidden == 2 || colInfo.Hidden == 3
1,826✔
381

1,826✔
382
                if colInfo.DefaultValue.Valid {
2,166✔
383
                        column.Default = colInfo.DefaultValue.String
340✔
384
                } else if autoIncr {
2,026✔
385
                        column.Default = "auto_increment"
200✔
386
                } else if column.Generated {
1,526✔
387
                        column.Default = "auto_generated"
40✔
388
                }
40✔
389

390
                if column.Nullable && column.Default == "" {
2,476✔
391
                        column.Default = "NULL"
650✔
392
                }
650✔
393

394
                column.Type = parser.TranslateColumnType(column.DBType)
1,826✔
395
                columns = append(columns, column)
1,826✔
396
        }
397

398
        return columns, nil
340✔
399
}
400

401
func (s driver) tableInfo(ctx context.Context, schema, tableName string) ([]info, error) {
340✔
402
        var ret []info
340✔
403
        rows, err := s.conn.QueryContext(ctx, fmt.Sprintf("PRAGMA '%s'.table_xinfo('%s')", schema, tableName))
340✔
404
        if err != nil {
340✔
405
                return nil, err
×
406
        }
×
407
        defer rows.Close()
340✔
408

340✔
409
        for rows.Next() {
2,190✔
410
                tinfo := info{}
1,850✔
411
                if err := rows.Scan(&tinfo.Cid, &tinfo.Name, &tinfo.Type, &tinfo.NotNull, &tinfo.DefaultValue, &tinfo.Pk, &tinfo.Hidden); err != nil {
1,850✔
412
                        return nil, fmt.Errorf("unable to scan for table %s: %w", tableName, err)
×
413
                }
×
414

415
                ret = append(ret, tinfo)
1,850✔
416
        }
417
        return ret, nil
340✔
418
}
419

420
// primaryKey looks up the primary key for a table.
421
func (s driver) primaryKey(schema, tableName string, tinfo []info) *drivers.Constraint[any] {
340✔
422
        var cols []string
340✔
423
        for _, c := range tinfo {
2,190✔
424
                if c.Pk > 0 {
2,180✔
425
                        cols = append(cols, c.Name)
330✔
426
                }
330✔
427
        }
428

429
        if len(cols) == 0 {
370✔
430
                return nil
30✔
431
        }
30✔
432

433
        return &drivers.Constraint[any]{
310✔
434
                Name:    fmt.Sprintf("pk_%s_%s", schema, tableName),
310✔
435
                Columns: cols,
310✔
436
        }
310✔
437
}
438

439
func (d driver) skipKey(table, column string) bool {
110✔
440
        if len(d.config.Only) > 0 {
110✔
441
                // check if the table is listed at all
×
442
                filter, ok := d.config.Only[table]
×
443
                if !ok {
×
444
                        return true
×
445
                }
×
446

447
                // check if the column is listed
448
                if len(filter) == 0 {
×
449
                        return false
×
450
                }
×
451

452
                for _, filteredCol := range filter {
×
453
                        if filteredCol == column {
×
454
                                return false
×
455
                        }
×
456
                }
457
                return true
×
458
        }
459

460
        if len(d.config.Except) > 0 {
154✔
461
                filter, ok := d.config.Except[table]
44✔
462
                if !ok {
88✔
463
                        return false
44✔
464
                }
44✔
465

466
                if len(filter) == 0 {
×
467
                        return true
×
468
                }
×
469

470
                for _, filteredCol := range filter {
×
471
                        if filteredCol == column {
×
472
                                return true
×
473
                        }
×
474
                }
475
        }
476

477
        return false
66✔
478
}
479

480
// foreignKeys retrieves the foreign keys for a given table name.
481
func (d driver) foreignKeys(ctx context.Context, schema, tableName string) ([]drivers.ForeignKey[any], error) {
340✔
482
        rows, err := d.conn.QueryContext(ctx, fmt.Sprintf("PRAGMA '%s'.foreign_key_list('%s')", schema, tableName))
340✔
483
        if err != nil {
340✔
484
                return nil, err
×
485
        }
×
486
        defer rows.Close()
340✔
487

340✔
488
        fkeyMap := make(map[int]drivers.ForeignKey[any])
340✔
489
        for rows.Next() {
450✔
490
                var id, seq int
110✔
491
                var ftable, col string
110✔
492
                var fcolNullable null.Val[string]
110✔
493

110✔
494
                // not used
110✔
495
                var onupdate, ondelete, match string
110✔
496

110✔
497
                err = rows.Scan(&id, &seq, &ftable, &col, &fcolNullable, &onupdate, &ondelete, &match)
110✔
498
                if err != nil {
110✔
499
                        return nil, err
×
500
                }
×
501

502
                fullFtable := ftable
110✔
503
                if schema != "main" {
150✔
504
                        fullFtable = fmt.Sprintf("%s.%s", schema, ftable)
40✔
505
                }
40✔
506

507
                fcol, _ := fcolNullable.Get()
110✔
508
                if fcol == "" {
120✔
509
                        fcol, err = stdscan.One(
10✔
510
                                ctx, d.conn, scan.SingleColumnMapper[string],
10✔
511
                                fmt.Sprintf("SELECT name FROM pragma_table_info('%s', '%s') WHERE pk = ?", ftable, schema), seq+1,
10✔
512
                        )
10✔
513
                        if err != nil {
10✔
514
                                return nil, fmt.Errorf("could not find column %q in table %q: %w", col, ftable, err)
×
515
                        }
×
516
                }
517

518
                if d.skipKey(fullFtable, fcol) {
110✔
519
                        continue
×
520
                }
521

522
                fkeyMap[id] = drivers.ForeignKey[any]{
110✔
523
                        Constraint: drivers.Constraint[any]{
110✔
524
                                Name:    fmt.Sprintf("fk_%s_%d", tableName, id),
110✔
525
                                Columns: append(fkeyMap[id].Columns, col),
110✔
526
                        },
110✔
527
                        ForeignTable:   d.key(schema, ftable),
110✔
528
                        ForeignColumns: append(fkeyMap[id].ForeignColumns, fcol),
110✔
529
                }
110✔
530
        }
531

532
        if err = rows.Err(); err != nil {
340✔
533
                return nil, err
×
534
        }
×
535

536
        fkeys := make([]drivers.ForeignKey[any], 0, len(fkeyMap))
340✔
537

340✔
538
        for _, fkey := range fkeyMap {
440✔
539
                fkeys = append(fkeys, fkey)
100✔
540
        }
100✔
541

542
        sort.Slice(fkeys, func(i, j int) bool {
380✔
543
                return fkeys[i].Name < fkeys[j].Name
40✔
544
        })
40✔
545

546
        return fkeys, nil
340✔
547
}
548

549
// info is one row of PRAGMA table_xinfo, as scanned by tableInfo.
type info struct {
	Cid          int            // column index within the table
	Name         string         // column name
	Type         string         // declared column type
	NotNull      bool           // whether the column is declared NOT NULL
	DefaultValue sql.NullString // default value expression; invalid when there is none
	Pk           int            // 0 when not in the primary key, else the column's 1-based position in it
	Hidden       int            // non-zero for hidden columns; columns() treats 2 and 3 as generated
}
560

561
func (d *driver) key(schema string, table string) string {
450✔
562
        key := table
450✔
563
        if schema != "" && schema != d.config.SharedSchema {
650✔
564
                key = schema + "." + table
200✔
565
        }
200✔
566

567
        return key
450✔
568
}
569

570
func (d *driver) schema(schema string) string {
340✔
571
        if schema == d.config.SharedSchema {
520✔
572
                return ""
180✔
573
        }
180✔
574

575
        return schema
160✔
576
}
577

578
func (d *driver) indexes(ctx context.Context, schema, tableName string) ([]drivers.Index[IndexExtra], error) {
340✔
579
        query := fmt.Sprintf(`
340✔
580
        SELECT name, "unique", origin, partial
340✔
581
        FROM pragma_index_list('%s', '%s') ORDER BY seq ASC
340✔
582
        `, tableName, schema)
340✔
583
        indexNames, err := stdscan.All(ctx, d.conn, scan.StructMapper[struct {
340✔
584
                Name    string
340✔
585
                Unique  bool
340✔
586
                Origin  string
340✔
587
                Partial bool
340✔
588
        }](), query)
340✔
589
        if err != nil {
340✔
590
                return nil, err
×
591
        }
×
592

593
        indexes := make([]drivers.Index[IndexExtra], len(indexNames))
340✔
594
        for i, index := range indexNames {
550✔
595
                cols, err := d.getIndexInformation(ctx, schema, tableName, index.Name)
210✔
596
                if err != nil {
210✔
597
                        return nil, err
×
598
                }
×
599
                indexes[i] = drivers.Index[IndexExtra]{
210✔
600
                        Type:    index.Origin,
210✔
601
                        Name:    index.Name,
210✔
602
                        Unique:  index.Unique,
210✔
603
                        Columns: cols,
210✔
604
                        Extra: IndexExtra{
210✔
605
                                Partial: index.Partial,
210✔
606
                        },
210✔
607
                }
210✔
608

609
        }
610

611
        return indexes, nil
340✔
612
}
613

614
func (d *driver) getIndexInformation(ctx context.Context, schema, tableName, indexName string) ([]drivers.IndexColumn, error) {
210✔
615
        colExpressions, err := d.extractIndexExpressions(ctx, schema, tableName, indexName)
210✔
616
        if err != nil {
210✔
617
                return nil, err
×
618
        }
×
619

620
        query := fmt.Sprintf(`
210✔
621
            SELECT seqno, name, desc
210✔
622
            FROM pragma_index_xinfo('%s', '%s')
210✔
623
            WHERE key = 1
210✔
624
            ORDER BY seqno ASC`,
210✔
625
                indexName, schema)
210✔
626

210✔
627
        var columns []drivers.IndexColumn //nolint:prealloc
210✔
628
        for column, err := range stdscan.Each(ctx, d.conn, scan.StructMapper[struct {
210✔
629
                Seqno int
210✔
630
                Name  sql.NullString
210✔
631
                Desc  bool
210✔
632
        }](), query) {
480✔
633
                if err != nil {
270✔
634
                        return nil, err
×
635
                }
×
636

637
                col := drivers.IndexColumn{
270✔
638
                        Name: column.Name.String,
270✔
639
                        Desc: column.Desc,
270✔
640
                }
270✔
641

270✔
642
                if !column.Name.Valid {
310✔
643
                        col.Name = colExpressions[column.Seqno]
40✔
644
                        col.IsExpression = true
40✔
645
                }
40✔
646

647
                columns = append(columns, col)
270✔
648
        }
649

650
        return columns, nil
210✔
651
}
652

653
func (d driver) extractIndexExpressions(ctx context.Context, schema, tableName, indexName string) ([]string, error) {
210✔
654
        var nullDDL sql.NullString
210✔
655

210✔
656
        //nolint:gosec
210✔
657
        query := fmt.Sprintf("SELECT sql FROM '%s'.sqlite_master WHERE type = 'index' AND name = ? AND tbl_name = ?", schema)
210✔
658
        result := d.conn.QueryRowContext(ctx, query, indexName, tableName)
210✔
659
        err := result.Scan(&nullDDL)
210✔
660
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
210✔
661
                return nil, fmt.Errorf("failed retrieving index DDL statement: %w", err)
×
662
        }
×
663

664
        if !nullDDL.Valid {
360✔
665
                return nil, nil
150✔
666
        }
150✔
667

668
        ddl := nullDDL.String
60✔
669
        // We're following the parsing logic from the `intckParseCreateIndex` function in the SQLite source code.
60✔
670
        // 1. https://github.com/sqlite/sqlite/blob/1d8cde9d56d153767e98595c4b015221864ef0e7/ext/intck/sqlite3intck.c#L363
60✔
671
        // 2. https://www.sqlite.org/lang_createindex.html
60✔
672

60✔
673
        // skip forward until the first "(" token
60✔
674
        i := strings.Index(ddl, "(")
60✔
675
        if i == -1 {
60✔
676
                return nil, fmt.Errorf("failed locating first column: %w", err)
×
677
        }
×
678
        ddl = ddl[i+1:]
60✔
679
        // discard the WHERE clause fragment (if one exists)
60✔
680
        i = strings.LastIndex(ddl, ")")
60✔
681
        if i == -1 {
60✔
682
                return nil, fmt.Errorf("failed locating last column: %w", err)
×
683
        }
×
684
        ddl = ddl[:i]
60✔
685
        // organize column definitions into a list
60✔
686
        colDefs := d.splitColumnDefinitions(ddl)
60✔
687

60✔
688
        expressions := make([]string, len(colDefs))
60✔
689
        for seqNo, expression := range colDefs {
150✔
690
                expressions[seqNo] = strings.TrimSpace(expression)
90✔
691
        }
90✔
692

693
        return expressions, nil
60✔
694
}
695

696
// splitColumnDefinitions performs an intelligent split of the DDL part defining the index columns.
697
//
698
// We cannot perform a simple `strings.Split(ddl, ",")` as `ddl` could contain functional expressions, i.e.:
699
//
700
//        sql  := CREATE INDEX idx ON test (col1, (col2 + col3), (POW(col3, 2)));
701
//        ddl  := "col1, (col2 + col3), (POW(col3, 2))"
702
//        defs := []string{"col1", "(col2 + col3)", "(POW(col3, 2))"}
703
func (d driver) splitColumnDefinitions(ddl string) []string {
60✔
704
        var defs []string
60✔
705
        var i, pOpen int
60✔
706

60✔
707
        for j := 0; j < len(ddl); j++ {
930✔
708
                if ddl[j] == '(' {
910✔
709
                        pOpen++
40✔
710
                }
40✔
711
                if ddl[j] == ')' {
910✔
712
                        pOpen--
40✔
713
                }
40✔
714
                if pOpen == 0 && ddl[j] == ',' {
900✔
715
                        defs = append(defs, ddl[i:j])
30✔
716
                        i = j + 1
30✔
717
                }
30✔
718
        }
719

720
        if i < len(ddl) {
120✔
721
                defs = append(defs, ddl[i:])
60✔
722
        }
60✔
723

724
        return defs
60✔
725
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc