• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stephenafamo / bob / 14640657403

24 Apr 2025 11:38AM UTC coverage: 44.274% (-3.6%) from 47.872%
14640657403

push

github

web-flow
Merge pull request #391 from stephenafamo/queries

Implement parsing of Postgres SELECT queries

66 of 1446 new or added lines in 23 files covered. (4.56%)

4 existing lines in 3 files now uncovered.

7481 of 16897 relevant lines covered (44.27%)

222.34 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

17.47
/gen/bobgen-sqlite/driver/parser/parse.go
1
package parser
2

3
import (
4
        "context"
5
        "errors"
6
        "fmt"
7
        "strings"
8

9
        "github.com/aarondl/opt/omit"
10
        "github.com/antlr4-go/antlr/v4"
11
        "github.com/stephenafamo/bob/gen/drivers"
12
        "github.com/stephenafamo/bob/internal"
13
        sqliteparser "github.com/stephenafamo/sqlparser/sqlite"
14
)
15

16
func New(t tables) Parser {
28✔
17
        return Parser{db: t}
28✔
18
}
28✔
19

20
type Parser struct {
21
        db tables
22
}
23

NEW
24
func (p Parser) ParseQueries(_ context.Context, s string) ([]drivers.Query, error) {
×
UNCOV
25
        v := NewVisitor(p.db)
×
26
        infos, err := p.parse(v, s)
×
27
        if err != nil {
×
28
                return nil, fmt.Errorf("parse: %w", err)
×
29
        }
×
30

31
        queries := make([]drivers.Query, len(infos))
×
32
        for i, info := range infos {
×
33
                stmtStart := info.stmt.GetStart().GetStart()
×
34
                stmtStop := info.stmt.GetStop().GetStop()
×
35
                formatted, err := internal.EditStringSegment(s, stmtStart, stmtStop, info.editRules...)
×
36
                if err != nil {
×
37
                        return nil, fmt.Errorf("format: %w", err)
×
38
                }
×
39

40
                cols := make([]drivers.QueryCol, len(info.columns))
×
41
                for i, col := range info.columns {
×
42
                        cols[i] = drivers.QueryCol{
×
43
                                Name:     col.name,
×
44
                                DBName:   col.name,
×
45
                                Nullable: omit.From(col.typ.Nullable()),
×
46
                                TypeName: col.typ.Type(p.db),
×
47
                        }.Merge(col.config)
×
48
                }
×
49

50
                args := v.getArgs(stmtStart, stmtStop)
×
51
                keys := make(map[string]int, len(args))
×
52
                names := make(map[string]int, len(args))
×
53
                queryArgs := make([]drivers.QueryArg, 0, len(args))
×
54
                for _, arg := range args {
×
55
                        // If the key is already in the map, append the position
×
56
                        key := arg.queryArgKey
×
57
                        if oldIndex, ok := keys[key]; ok && key != "" {
×
58
                                queryArgs[oldIndex].Positions = append(
×
59
                                        queryArgs[oldIndex].Positions, arg.EditedPosition,
×
60
                                )
×
61
                                continue
×
62
                        }
63
                        keys[arg.queryArgKey] = len(queryArgs)
×
64

×
65
                        name := v.getNameString(arg.expr)
×
66
                        index := names[name]
×
67
                        names[name] = index + 1
×
68
                        if index > 0 {
×
69
                                name = fmt.Sprintf("%s_%d", name, index+1)
×
70
                        }
×
71

72
                        queryArgs = append(queryArgs, drivers.QueryArg{
×
73
                                Col: drivers.QueryCol{
×
74
                                        Name:     name,
×
75
                                        Nullable: omit.From(arg.Type.Nullable()),
×
76
                                        TypeName: v.getDBType(arg).Type(p.db),
×
77
                                }.Merge(arg.config),
×
78
                                Positions:     [][2]int{arg.EditedPosition},
×
79
                                CanBeMultiple: arg.CanBeMultiple,
×
80
                        })
×
81
                }
82

83
                name, configStr, _ := strings.Cut(info.comment, " ")
×
84
                queries[i] = drivers.Query{
×
85
                        Name: name,
×
86
                        SQL:  formatted,
×
87
                        Type: info.queryType,
×
88

×
89
                        Config: drivers.QueryConfig{
×
90
                                RowName:      info.comment + "Row",
×
91
                                RowSliceName: "",
×
92
                                GenerateRow:  true,
×
93
                        }.Merge(drivers.ParseQueryConfig(configStr)),
×
94

×
95
                        Columns: cols,
×
96
                        Args:    groupArgs(queryArgs),
×
97
                        Mods:    stmtToMod{info},
×
98
                }
×
99
        }
100

101
        return queries, nil
×
102
}
103

104
func (Parser) parse(v *visitor, input string) ([]stmtInfo, error) {
2✔
105
        el := &errorListener{}
2✔
106

2✔
107
        // Get all hidden tokens (usually comments) and add edit rules to remove them
2✔
108
        v.baseRules = []internal.EditRule{}
2✔
109
        hiddenLexer := sqliteparser.NewSQLiteLexer(antlr.NewInputStream(input))
2✔
110
        hiddenStream := antlr.NewCommonTokenStream(hiddenLexer, 1)
2✔
111
        hiddenStream.Fill()
2✔
112
        for _, token := range hiddenStream.GetAllTokens() {
482✔
113
                switch token.GetTokenType() {
480✔
114
                case sqliteparser.SQLiteParserSINGLE_LINE_COMMENT,
115
                        sqliteparser.SQLiteParserMULTILINE_COMMENT:
6✔
116
                        v.baseRules = append(
6✔
117
                                v.baseRules,
6✔
118
                                internal.Delete(token.GetStart(), token.GetStop()),
6✔
119
                        )
6✔
120
                }
121
        }
122

123
        // Get the regular tokens (usually the SQL statement)
124
        lexer := sqliteparser.NewSQLiteLexer(antlr.NewInputStream(input))
2✔
125
        stream := antlr.NewCommonTokenStream(lexer, 0)
2✔
126
        sqlParser := sqliteparser.NewSQLiteParser(stream)
2✔
127
        sqlParser.AddErrorListener(el)
2✔
128

2✔
129
        tree := sqlParser.Parse()
2✔
130
        if el.err != "" {
2✔
131
                return nil, errors.New(el.err)
×
132
        }
×
133

134
        infos, ok := tree.Accept(v).([]stmtInfo)
2✔
135
        if v.err != nil {
2✔
136
                return nil, fmt.Errorf("visitor: %w", v.err)
×
137
        }
×
138

139
        if !ok {
2✔
140
                return nil, fmt.Errorf("visitor: expected stmtInfo, got %T", infos)
×
141
        }
×
142

143
        return infos, nil
2✔
144
}
145

146
type stmtToMod struct {
147
        info stmtInfo
148
}
149

150
func (s stmtToMod) IncludeInTemplate(i drivers.Importer) string {
×
151
        for _, im := range s.info.imports {
×
152
                i.Import(im...)
×
153
        }
×
154
        return s.info.mods.String()
×
155
}
156

157
type errorListener struct {
158
        *antlr.DefaultErrorListener
159

160
        err string
161
}
162

163
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
×
164
        el.err = msg
×
165
}
×
166

167
func groupArgs(args []drivers.QueryArg) []drivers.QueryArg {
×
168
        newArgs := make([]drivers.QueryArg, 0, len(args))
×
169

×
170
Outer:
×
171
        for i, arg := range args {
×
172
                if len(arg.Positions) != 1 {
×
173
                        newArgs = append(newArgs, args[i])
×
174
                        continue
×
175
                }
176

177
                for j, arg2 := range args {
×
178
                        if i == j {
×
179
                                continue
×
180
                        }
181

182
                        if len(arg2.Positions) != 1 {
×
183
                                continue
×
184
                        }
185

186
                        if arg2.Positions[0][0] <= arg.Positions[0][0] &&
×
187
                                arg2.Positions[0][1] >= arg.Positions[0][1] {
×
188
                                // arg2 is a parent of arg
×
189
                                // since arg1 has a parent, it should be skipped
×
190
                                continue Outer
×
191
                        }
192

193
                        if arg.Positions[0][0] <= arg2.Positions[0][0] &&
×
194
                                arg.Positions[0][1] >= arg2.Positions[0][1] {
×
195
                                // arg is a parent of arg2
×
196
                                args[i].Children = append(args[i].Children, arg2)
×
197
                        }
×
198
                }
199

200
                newArgs = append(newArgs, args[i])
×
201
        }
202

203
        return newArgs
×
204
}
205

206
//nolint:gochecknoglobals
207
var defaultFunctions = functions{
208
        "abs": {
209
                requiredArgs: 1,
210
                args:         []string{""},
211
                calcReturnType: func(args ...string) string {
×
212
                        if args[0] == "INTEGER" {
×
213
                                return "INTEGER"
×
214
                        }
×
215
                        return "REAL"
×
216
                },
217
        },
218
        "changes": {
219
                returnType: "INTEGER",
220
        },
221
        "char": {
222
                requiredArgs: 1,
223
                variadic:     true,
224
                args:         []string{"INTEGER"},
225
                returnType:   "TEXT",
226
        },
227
        "coalesce": {
228
                requiredArgs:         1,
229
                variadic:             true,
230
                args:                 []string{""},
231
                shouldArgsBeNullable: true,
232
                calcReturnType: func(args ...string) string {
×
233
                        for _, arg := range args {
×
234
                                if arg != "" {
×
235
                                        return arg
×
236
                                }
×
237
                        }
238
                        return ""
×
239
                },
240
                calcNullable: allNullable,
241
        },
242
        "concat": {
243
                requiredArgs: 1,
244
                variadic:     true,
245
                args:         []string{"TEXT"},
246
                returnType:   "TEXT",
247
                calcNullable: neverNullable,
248
        },
249
        "concat_ws": {
250
                requiredArgs: 2,
251
                variadic:     true,
252
                args:         []string{"TEXT", "TEXT"},
253
                returnType:   "TEXT",
254
                calcNullable: func(args ...func() bool) func() bool {
×
255
                        return args[0]
×
256
                },
×
257
        },
258
        "format": {
259
                requiredArgs: 2,
260
                variadic:     true,
261
                args:         []string{"TEXT", ""},
262
                returnType:   "TEXT",
263
                calcNullable: func(args ...func() bool) func() bool {
×
264
                        return args[0]
×
265
                },
×
266
        },
267
        "glob": {
268
                requiredArgs: 2,
269
                args:         []string{"TEXT", "TEXT"},
270
                returnType:   "BOOLEAN",
271
        },
272
        "hex": {
273
                requiredArgs: 1,
274
                args:         []string{""},
275
                returnType:   "TEXT",
276
        },
277
        "ifnull": {
278
                requiredArgs: 2,
279
                args:         []string{""},
280
                calcReturnType: func(args ...string) string {
×
281
                        for _, arg := range args {
×
282
                                if arg != "" {
×
283
                                        return arg
×
284
                                }
×
285
                        }
286
                        return ""
×
287
                },
288
                calcNullable: allNullable,
289
        },
290
        "iif": {
291
                requiredArgs: 3,
292
                args:         []string{"BOOLEAN", "", ""},
293
                calcReturnType: func(args ...string) string {
×
294
                        return args[1]
×
295
                },
×
296
                calcNullable: func(args ...func() bool) func() bool {
×
297
                        return anyNullable(args[1], args[2])
×
298
                },
×
299
        },
300
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc