• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stephenafamo / bob / 14351235503

09 Apr 2025 07:14AM UTC coverage: 47.87% (-1.5%) from 49.32%
14351235503

Pull #388

github

stephenafamo
Implement parsing of SQLite SELECT queries
Pull Request #388: Implement parsing of SQLite SELECT queries

1093 of 2670 new or added lines in 29 files covered. (40.94%)

4 existing lines in 4 files now uncovered.

7471 of 15607 relevant lines covered (47.87%)

240.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

15.87
/gen/bobgen-sqlite/driver/parser/parse.go
1
package parser
2

3
import (
4
        "errors"
5
        "fmt"
6
        "os"
7
        "path/filepath"
8
        "strings"
9

10
        "github.com/aarondl/opt/omit"
11
        "github.com/antlr4-go/antlr/v4"
12
        "github.com/stephenafamo/bob/gen/drivers"
13
        "github.com/stephenafamo/bob/internal"
14
        sqliteparser "github.com/stephenafamo/sqlparser/sqlite"
15
)
16

17
func New(t tables) Parser {
28✔
18
        return Parser{db: t}
28✔
19
}
28✔
20

21
type Parser struct {
22
        db tables
23
}
24

25
func (p Parser) ParseFolders(paths ...string) ([]drivers.QueryFolder, error) {
26✔
26
        allQueries := make([]drivers.QueryFolder, 0, len(paths))
26✔
27
        for _, path := range paths {
26✔
NEW
28
                queries, err := p.parseFolder(path)
×
NEW
29
                if err != nil {
×
NEW
30
                        return nil, fmt.Errorf("parse folder: %w", err)
×
NEW
31
                }
×
32

NEW
33
                allQueries = append(allQueries, queries)
×
34
        }
35

36
        return allQueries, nil
26✔
37
}
38

NEW
39
func (p Parser) parseFolder(path string) (drivers.QueryFolder, error) {
×
NEW
40
        entries, err := os.ReadDir(path)
×
NEW
41
        if err != nil {
×
NEW
42
                return drivers.QueryFolder{}, fmt.Errorf("read dir: %w", err)
×
NEW
43
        }
×
44

NEW
45
        files := make([]drivers.QueryFile, 0, len(entries))
×
NEW
46
        for _, entry := range entries {
×
NEW
47
                if entry.IsDir() {
×
NEW
48
                        continue
×
49
                }
50

NEW
51
                if filepath.Ext(entry.Name()) != ".sql" {
×
NEW
52
                        continue
×
53
                }
54

NEW
55
                file, err := p.parseFile(filepath.Join(path, entry.Name()))
×
NEW
56
                if err != nil {
×
NEW
57
                        return drivers.QueryFolder{}, fmt.Errorf("parse file: %w", err)
×
NEW
58
                }
×
59

NEW
60
                files = append(files, file)
×
61
        }
62

NEW
63
        return drivers.QueryFolder{
×
NEW
64
                Path:  path,
×
NEW
65
                Files: files,
×
NEW
66
        }, nil
×
67
}
68

NEW
69
func (p Parser) parseFile(path string) (drivers.QueryFile, error) {
×
NEW
70
        file, err := os.ReadFile(path)
×
NEW
71
        if err != nil {
×
NEW
72
                return drivers.QueryFile{}, fmt.Errorf("read file: %w", err)
×
NEW
73
        }
×
74

NEW
75
        queries, err := p.parseMultiQueries(string(file))
×
NEW
76
        if err != nil {
×
NEW
77
                return drivers.QueryFile{}, fmt.Errorf("parse multi queries: %w", err)
×
NEW
78
        }
×
79

NEW
80
        return drivers.QueryFile{
×
NEW
81
                Path:    path,
×
NEW
82
                Queries: queries,
×
NEW
83
        }, nil
×
84
}
85

NEW
86
func (p Parser) parseMultiQueries(s string) ([]drivers.Query, error) {
×
NEW
87
        v := NewVisitor(p.db)
×
NEW
88
        infos, err := p.parse(v, s)
×
NEW
89
        if err != nil {
×
NEW
90
                return nil, fmt.Errorf("parse: %w", err)
×
NEW
91
        }
×
92

NEW
93
        queries := make([]drivers.Query, len(infos))
×
NEW
94
        for i, info := range infos {
×
NEW
95
                stmtStart := info.stmt.GetStart().GetStart()
×
NEW
96
                stmtStop := info.stmt.GetStop().GetStop()
×
NEW
97
                formatted, err := internal.EditStringSegment(s, stmtStart, stmtStop, info.editRules...)
×
NEW
98
                if err != nil {
×
NEW
99
                        return nil, fmt.Errorf("format: %w", err)
×
NEW
100
                }
×
101

NEW
102
                cols := make([]drivers.QueryCol, len(info.columns))
×
NEW
103
                for i, col := range info.columns {
×
NEW
104
                        cols[i] = drivers.QueryCol{
×
NEW
105
                                Name:     col.name,
×
NEW
106
                                DBName:   col.name,
×
NEW
107
                                Nullable: omit.From(col.typ.Nullable()),
×
NEW
108
                                TypeName: col.typ.Type(p.db),
×
NEW
109
                        }.Merge(col.config)
×
NEW
110
                }
×
111

NEW
112
                args := v.getArgs(stmtStart, stmtStop)
×
NEW
113
                keys := make(map[string]int, len(args))
×
NEW
114
                names := make(map[string]int, len(args))
×
NEW
115
                queryArgs := make([]drivers.QueryArg, 0, len(args))
×
NEW
116
                for _, arg := range args {
×
NEW
117
                        // If the key is already in the map, append the position
×
NEW
118
                        key := arg.queryArgKey
×
NEW
119
                        if oldIndex, ok := keys[key]; ok && key != "" {
×
NEW
120
                                queryArgs[oldIndex].Positions = append(
×
NEW
121
                                        queryArgs[oldIndex].Positions, arg.EditedPosition,
×
NEW
122
                                )
×
NEW
123
                                continue
×
124
                        }
NEW
125
                        keys[arg.queryArgKey] = len(queryArgs)
×
NEW
126

×
NEW
127
                        name := v.getNameString(arg.expr)
×
NEW
128
                        index := names[name]
×
NEW
129
                        names[name] = index + 1
×
NEW
130
                        if index > 0 {
×
NEW
131
                                name = fmt.Sprintf("%s_%d", name, index+1)
×
NEW
132
                        }
×
133

NEW
134
                        queryArgs = append(queryArgs, drivers.QueryArg{
×
NEW
135
                                Col: drivers.QueryCol{
×
NEW
136
                                        Name:     name,
×
NEW
137
                                        Nullable: omit.From(arg.Type.Nullable()),
×
NEW
138
                                        TypeName: v.getDBType(arg).Type(p.db),
×
NEW
139
                                }.Merge(arg.config),
×
NEW
140
                                Positions:     [][2]int{arg.EditedPosition},
×
NEW
141
                                CanBeMultiple: arg.CanBeMultiple,
×
NEW
142
                        })
×
143
                }
144

NEW
145
                name, configStr, _ := strings.Cut(info.comment, " ")
×
NEW
146
                queries[i] = drivers.Query{
×
NEW
147
                        Name: name,
×
NEW
148
                        SQL:  formatted,
×
NEW
149
                        Type: info.queryType,
×
NEW
150

×
NEW
151
                        Config: drivers.QueryConfig{
×
NEW
152
                                RowName:      info.comment + "Row",
×
NEW
153
                                RowSliceName: "",
×
NEW
154
                                GenerateRow:  true,
×
NEW
155
                        }.Merge(drivers.ParseQueryConfig(configStr)),
×
NEW
156

×
NEW
157
                        Columns: cols,
×
NEW
158
                        Args:    groupArgs(queryArgs),
×
NEW
159
                        Mods:    stmtToMod{info},
×
NEW
160
                }
×
161
        }
162

NEW
163
        return queries, nil
×
164
}
165

166
func (Parser) parse(v *visitor, input string) ([]stmtInfo, error) {
2✔
167
        el := &errorListener{}
2✔
168

2✔
169
        // Get all hidden tokens (usually comments) and add edit rules to remove them
2✔
170
        v.baseRules = []internal.EditRule{}
2✔
171
        hiddenLexer := sqliteparser.NewSQLiteLexer(antlr.NewInputStream(input))
2✔
172
        hiddenStream := antlr.NewCommonTokenStream(hiddenLexer, 1)
2✔
173
        hiddenStream.Fill()
2✔
174
        for _, token := range hiddenStream.GetAllTokens() {
482✔
175
                switch token.GetTokenType() {
480✔
176
                case sqliteparser.SQLiteParserSINGLE_LINE_COMMENT,
177
                        sqliteparser.SQLiteParserMULTILINE_COMMENT:
6✔
178
                        v.baseRules = append(
6✔
179
                                v.baseRules,
6✔
180
                                internal.Delete(token.GetStart(), token.GetStop()),
6✔
181
                        )
6✔
182
                }
183
        }
184

185
        // Get the regular tokens (usually the SQL statement)
186
        lexer := sqliteparser.NewSQLiteLexer(antlr.NewInputStream(input))
2✔
187
        stream := antlr.NewCommonTokenStream(lexer, 0)
2✔
188
        sqlParser := sqliteparser.NewSQLiteParser(stream)
2✔
189
        sqlParser.AddErrorListener(el)
2✔
190

2✔
191
        tree := sqlParser.Parse()
2✔
192
        if el.err != "" {
2✔
NEW
193
                return nil, errors.New(el.err)
×
NEW
194
        }
×
195

196
        infos, ok := tree.Accept(v).([]stmtInfo)
2✔
197
        if v.err != nil {
2✔
NEW
198
                return nil, fmt.Errorf("visitor: %w", v.err)
×
NEW
199
        }
×
200

201
        if !ok {
2✔
NEW
202
                return nil, fmt.Errorf("visitor: expected stmtInfo, got %T", infos)
×
NEW
203
        }
×
204

205
        return infos, nil
2✔
206
}
207

208
type stmtToMod struct {
209
        info stmtInfo
210
}
211

NEW
212
func (s stmtToMod) IncludeInTemplate(i drivers.Importer) string {
×
NEW
213
        for _, im := range s.info.imports {
×
NEW
214
                i.Import(im...)
×
NEW
215
        }
×
NEW
216
        return s.info.mods.String()
×
217
}
218

219
type errorListener struct {
220
        *antlr.DefaultErrorListener
221

222
        err string
223
}
224

NEW
225
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
×
NEW
226
        el.err = msg
×
NEW
227
}
×
228

NEW
229
func groupArgs(args []drivers.QueryArg) []drivers.QueryArg {
×
NEW
230
        newArgs := make([]drivers.QueryArg, 0, len(args))
×
NEW
231

×
NEW
232
Outer:
×
NEW
233
        for i, arg := range args {
×
NEW
234
                if len(arg.Positions) != 1 {
×
NEW
235
                        newArgs = append(newArgs, args[i])
×
NEW
236
                        continue
×
237
                }
238

NEW
239
                for j, arg2 := range args {
×
NEW
240
                        if i == j {
×
NEW
241
                                continue
×
242
                        }
243

NEW
244
                        if len(arg2.Positions) != 1 {
×
NEW
245
                                continue
×
246
                        }
247

NEW
248
                        if arg2.Positions[0][0] <= arg.Positions[0][0] &&
×
NEW
249
                                arg2.Positions[0][1] >= arg.Positions[0][1] {
×
NEW
250
                                // arg2 is a parent of arg
×
NEW
251
                                // since arg1 has a parent, it should be skipped
×
NEW
252
                                continue Outer
×
253
                        }
254

NEW
255
                        if arg.Positions[0][0] <= arg2.Positions[0][0] &&
×
NEW
256
                                arg.Positions[0][1] >= arg2.Positions[0][1] {
×
NEW
257
                                // arg is a parent of arg2
×
NEW
258
                                args[i].Children = append(args[i].Children, arg2)
×
NEW
259
                        }
×
260
                }
261

NEW
262
                newArgs = append(newArgs, args[i])
×
263
        }
264

NEW
265
        return newArgs
×
266
}
267

268
//nolint:gochecknoglobals
269
var defaultFunctions = functions{
270
        "abs": {
271
                requiredArgs: 1,
272
                args:         []string{""},
NEW
273
                calcReturnType: func(args ...string) string {
×
NEW
274
                        if args[0] == "INTEGER" {
×
NEW
275
                                return "INTEGER"
×
NEW
276
                        }
×
NEW
277
                        return "REAL"
×
278
                },
279
        },
280
        "changes": {
281
                returnType: "INTEGER",
282
        },
283
        "char": {
284
                requiredArgs: 1,
285
                variadic:     true,
286
                args:         []string{"INTEGER"},
287
                returnType:   "TEXT",
288
        },
289
        "coalesce": {
290
                requiredArgs:         1,
291
                variadic:             true,
292
                args:                 []string{""},
293
                shouldArgsBeNullable: true,
NEW
294
                calcReturnType: func(args ...string) string {
×
NEW
295
                        for _, arg := range args {
×
NEW
296
                                if arg != "" {
×
NEW
297
                                        return arg
×
NEW
298
                                }
×
299
                        }
NEW
300
                        return ""
×
301
                },
302
                calcNullable: allNullable,
303
        },
304
        "concat": {
305
                requiredArgs: 1,
306
                variadic:     true,
307
                args:         []string{"TEXT"},
308
                returnType:   "TEXT",
309
                calcNullable: neverNullable,
310
        },
311
        "concat_ws": {
312
                requiredArgs: 2,
313
                variadic:     true,
314
                args:         []string{"TEXT", "TEXT"},
315
                returnType:   "TEXT",
NEW
316
                calcNullable: func(args ...func() bool) func() bool {
×
NEW
317
                        return args[0]
×
NEW
318
                },
×
319
        },
320
        "format": {
321
                requiredArgs: 2,
322
                variadic:     true,
323
                args:         []string{"TEXT", ""},
324
                returnType:   "TEXT",
NEW
325
                calcNullable: func(args ...func() bool) func() bool {
×
NEW
326
                        return args[0]
×
NEW
327
                },
×
328
        },
329
        "glob": {
330
                requiredArgs: 2,
331
                args:         []string{"TEXT", "TEXT"},
332
                returnType:   "BOOLEAN",
333
        },
334
        "hex": {
335
                requiredArgs: 1,
336
                args:         []string{""},
337
                returnType:   "TEXT",
338
        },
339
        "ifnull": {
340
                requiredArgs: 2,
341
                args:         []string{""},
NEW
342
                calcReturnType: func(args ...string) string {
×
NEW
343
                        for _, arg := range args {
×
NEW
344
                                if arg != "" {
×
NEW
345
                                        return arg
×
NEW
346
                                }
×
347
                        }
NEW
348
                        return ""
×
349
                },
350
                calcNullable: allNullable,
351
        },
352
        "iif": {
353
                requiredArgs: 3,
354
                args:         []string{"BOOLEAN", "", ""},
NEW
355
                calcReturnType: func(args ...string) string {
×
NEW
356
                        return args[1]
×
NEW
357
                },
×
NEW
358
                calcNullable: func(args ...func() bool) func() bool {
×
NEW
359
                        return anyNullable(args[1], args[2])
×
NEW
360
                },
×
361
        },
362
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc