• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bokwoon95 / sqddl / 10872614510

15 Sep 2024 04:33PM UTC coverage: 75.203% (+0.4%) from 74.831%
10872614510

push

github

bokwoon95
make StructParser detect struct{} fields. fixes #8

5 of 5 new or added lines in 1 file covered. (100.0%)

2 existing lines in 1 file now uncovered.

8231 of 10945 relevant lines covered (75.2%)

7195.24 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

56.51
/ddl/struct_parser.go
1
package ddl
2

3
import (
4
        "bytes"
5
        "fmt"
6
        "go/ast"
7
        "go/parser"
8
        "go/token"
9
        "io/fs"
10
        "reflect"
11
        "strconv"
12
        "strings"
13
        "unicode"
14

15
        "golang.org/x/tools/go/analysis"
16
        "golang.org/x/tools/go/analysis/passes/inspect"
17
        "golang.org/x/tools/go/ast/inspector"
18
)
19

20
// StructParser is used to parse Go source code into TableStructs.
21
type StructParser struct {
22
        TableStructs       TableStructs
23
        parserDiagnostics  *parserDiagnostics
24
        dialect            string
25
        locations          map[[2]string]location
26
        columnExplicitType map[[3]string]struct{}
27
        cache              *CatalogCache
28
}
29

30
// NewStructParser creates a new StructParser. An existing token.Fileset can be
31
// passed in. If not, passing in nil is fine and a new token.FileSet will be
32
// instantiated.
33
func NewStructParser(fset *token.FileSet) *StructParser {
233✔
34
        if fset == nil {
466✔
35
                fset = token.NewFileSet()
233✔
36
        }
233✔
37
        return &StructParser{parserDiagnostics: &parserDiagnostics{
233✔
38
                fset: fset,
233✔
39
        }}
233✔
40
}
41

42
// VisitNode is a callback function that populates the TableStructs when passed
43
// to ast.Inspect().
44
func (p *StructParser) VisitNode(node ast.Node) bool {
3,033✔
45
        switch node.(type) {
3,033✔
46
        case *ast.File, *ast.GenDecl:
1,011✔
47
                return true
1,011✔
48
        }
49
        p.VisitStruct(node)
2,022✔
50
        return false
2,022✔
51
}
52

53
// VisitStruct is a callback function that populates the TableStructs when
54
// passed to inspect.Inspector.Preorder(). It expects the node to be of type
55
// *ast.TypeSpec.
56
func (p *StructParser) VisitStruct(node ast.Node) {
2,022✔
57
        // Is it a type declaration?
2,022✔
58
        typeSpec, ok := node.(*ast.TypeSpec)
2,022✔
59
        if !ok {
3,479✔
60
                return
1,457✔
61
        }
1,457✔
62
        // Is it a type declaration for a struct?
63
        structType, ok := typeSpec.Type.(*ast.StructType)
565✔
64
        if !ok {
565✔
65
                return
×
66
        }
×
67
        // Does the struct have fields?
68
        if structType.Fields == nil || structType.Fields.List == nil {
565✔
69
                return
×
70
        }
×
71
        tableStruct := TableStruct{
565✔
72
                Name:   typeSpec.Name.Name,
565✔
73
                Fields: make([]StructField, 0, len(structType.Fields.List)),
565✔
74
        }
565✔
75
        for i, astField := range structType.Fields.List {
3,126✔
76
                var structField StructField
2,561✔
77
                // Name
2,561✔
78
                if len(astField.Names) > 0 && astField.Names[0] != nil {
4,557✔
79
                        structField.Name = astField.Names[0].Name
1,996✔
80
                }
1,996✔
81
                // Type
82
                if typ, ok := astField.Type.(*ast.SelectorExpr); ok {
5,097✔
83
                        if x, ok := typ.X.(*ast.Ident); ok {
5,072✔
84
                                structField.Type = x.Name + "." + typ.Sel.Name
2,536✔
85
                        }
2,536✔
86
                } else if typ, ok := astField.Type.(*ast.Ident); ok {
25✔
87
                        structField.Type = typ.Name
×
88
                } else if typ, ok := astField.Type.(*ast.StructType); ok {
50✔
89
                        if typ.Fields == nil || typ.Fields.List == nil {
50✔
90
                                structField.Type = "struct{}"
25✔
91
                        }
25✔
92
                }
93
                // Tag
94
                if astField.Tag != nil {
4,351✔
95
                        structField.tagPos = astField.Tag.Pos()
1,790✔
96
                        if tag, err := strconv.Unquote(astField.Tag.Value); err == nil {
3,580✔
97
                                structField.NameTag = reflect.StructTag(tag).Get("sq")
1,790✔
98
                                structField.Modifiers, err = NewModifiers(reflect.StructTag(tag).Get("ddl"))
1,790✔
99
                                if err != nil {
1,790✔
100
                                        loc := location{
×
101
                                                pos:        structField.tagPos,
×
102
                                                structName: tableStruct.Name,
×
103
                                                fieldName:  structField.Name,
×
104
                                        }
×
105
                                        p.report(loc, err.Error())
×
106
                                        continue
×
107
                                }
108
                        }
109
                }
110
                // If the first field is not sq.TableStruct, skip this struct entirely.
111
                if i == 0 {
3,126✔
112
                        if structField.Name != "" || structField.Type != "sq.TableStruct" {
565✔
113
                                structNameIsUppercase := true
×
114
                                for _, char := range tableStruct.Name {
×
115
                                        if !unicode.IsUpper(char) {
×
116
                                                structNameIsUppercase = false
×
117
                                                break
×
118
                                        }
119
                                }
120
                                if structNameIsUppercase && structField.Type != "sq.TableStruct" {
×
121
                                        loc := location{pos: structField.tagPos}
×
122
                                        p.report(loc, "struct "+tableStruct.Name+" is all uppercase but no sq.TableStruct field was found")
×
123
                                }
×
124
                                return
×
125
                        }
126
                }
127
                tableStruct.Fields = append(tableStruct.Fields, structField)
2,561✔
128
        }
129
        p.TableStructs = append(p.TableStructs, tableStruct)
565✔
130
}
131

132
// ParseFile parses an fs.File containing Go source code and populates the
133
// TableStructs.
134
func (p *StructParser) ParseFile(f fs.File) error {
233✔
135
        fileinfo, err := f.Stat()
233✔
136
        if err != nil {
233✔
137
                return err
×
138
        }
×
139
        file, err := parser.ParseFile(p.parserDiagnostics.fset, fileinfo.Name(), f, 0)
233✔
140
        if err != nil {
233✔
141
                return err
×
142
        }
×
143
        ast.Inspect(file, p.VisitNode)
233✔
144
        return p.Error()
233✔
145
}
146

147
// WriteCatalog populates the Catalog using the StructParser's TableStructs.
148
func (p *StructParser) WriteCatalog(catalog *Catalog) error {
233✔
149
        p.dialect = catalog.Dialect
233✔
150
        p.locations = make(map[[2]string]location)
233✔
151
        p.columnExplicitType = make(map[[3]string]struct{})
233✔
152
        p.cache = NewCatalogCache(catalog)
233✔
153

233✔
154
        for _, tableStruct := range p.TableStructs {
798✔
155
                if len(tableStruct.Fields) == 0 {
565✔
156
                        continue
×
157
                }
158
                var tableSchema string
565✔
159
                tableName := strings.ToLower(tableStruct.Name)
565✔
160
                if tableStruct.Fields[0].NameTag != "" {
678✔
161
                        tableName = tableStruct.Fields[0].NameTag
113✔
162
                }
113✔
163
                if i := strings.IndexByte(tableName, '.'); i >= 0 {
678✔
164
                        tableSchema, tableName = tableName[:i], tableName[i+1:]
113✔
165
                }
113✔
166
                if tableSchema == "" && catalog.CurrentSchema != "" {
776✔
167
                        tableSchema = catalog.CurrentSchema
211✔
168
                }
211✔
169

170
                schema := p.cache.GetOrCreateSchema(catalog, tableSchema)
565✔
171
                table := p.cache.GetOrCreateTable(schema, tableName)
565✔
172

565✔
173
                // The main loop.
565✔
174
                for _, structField := range tableStruct.Fields {
3,126✔
175
                        loc := location{
2,561✔
176
                                pos:        structField.tagPos,
2,561✔
177
                                structName: tableStruct.Name,
2,561✔
178
                                fieldName:  structField.Name,
2,561✔
179
                        }
2,561✔
180
                        if (structField.Name == "" && structField.Type == "sq.TableStruct") || (structField.Name == "_" && structField.Type == "struct{}") {
3,151✔
181
                                p.parseTableModifiers(table, loc, structField.Modifiers)
590✔
182
                                continue
590✔
183
                        }
184
                        columnName := strings.ToLower(structField.Name)
1,971✔
185
                        if structField.NameTag != "" {
1,971✔
186
                                columnName = structField.NameTag
×
187
                        }
×
188
                        var columnType, characterLength string
1,971✔
189
                        switch structField.Type {
1,971✔
190
                        case "sq.AnyField":
54✔
191
                        case "sq.ArrayField":
×
192
                                switch p.dialect {
×
193
                                case DialectSQLite, DialectMySQL:
×
194
                                        columnType = "JSON"
×
195
                                case DialectPostgres:
×
196
                                        columnType = "TEXT[]"
×
197
                                case DialectSQLServer:
×
198
                                        columnType, characterLength = "NVARCHAR(MAX)", "MAX"
×
199
                                default:
×
200
                                        columnType, characterLength = "VARCHAR(255)", "255"
×
201
                                }
202
                        case "sq.BinaryField":
1✔
203
                                switch p.dialect {
1✔
204
                                case DialectSQLite:
1✔
205
                                        columnType = "BLOB"
1✔
206
                                case DialectPostgres:
×
207
                                        columnType = "BYTEA"
×
208
                                case DialectMySQL:
×
209
                                        columnType = "MEDIUMBLOB"
×
210
                                case DialectSQLServer:
×
211
                                        columnType, characterLength = "VARBINARY(MAX)", "MAX"
×
212
                                default:
×
213
                                        columnType = "BINARY"
×
214
                                }
215
                        case "sq.BooleanField":
5✔
216
                                switch p.dialect {
5✔
217
                                case DialectSQLServer:
×
218
                                        columnType = "BIT"
×
219
                                default:
5✔
220
                                        columnType = "BOOLEAN"
5✔
221
                                }
222
                        case "sq.EnumField":
×
223
                                switch p.dialect {
×
224
                                case DialectSQLite, DialectPostgres:
×
225
                                        columnType = "TEXT"
×
226
                                case DialectSQLServer:
×
227
                                        columnType, characterLength = "NVARCHAR(255)", "255"
×
228
                                default:
×
229
                                        columnType, characterLength = "VARCHAR(255)", "255"
×
230
                                }
231
                        case "sq.JSONField":
32✔
232
                                switch p.dialect {
32✔
233
                                case DialectSQLite, DialectMySQL:
12✔
234
                                        columnType = "JSON"
12✔
235
                                case DialectPostgres:
10✔
236
                                        columnType = "JSONB"
10✔
237
                                case DialectSQLServer:
10✔
238
                                        columnType, characterLength = "NVARCHAR(MAX)", "MAX"
10✔
239
                                default:
×
240
                                        columnType = "VARCHAR(255)"
×
241
                                }
242
                        case "sq.NumberField":
944✔
243
                                columnType = "INT"
944✔
244
                        case "sq.StringField":
868✔
245
                                switch p.dialect {
868✔
246
                                case DialectSQLite, DialectPostgres:
396✔
247
                                        columnType = "TEXT"
396✔
248
                                case DialectSQLServer:
235✔
249
                                        columnType, characterLength = "NVARCHAR(255)", "255"
235✔
250
                                default:
237✔
251
                                        columnType, characterLength = "VARCHAR(255)", "255"
237✔
252
                                }
253
                        case "sq.TimeField":
59✔
254
                                switch p.dialect {
59✔
255
                                case DialectPostgres:
5✔
256
                                        columnType = "TIMESTAMPTZ"
5✔
257
                                case DialectSQLServer:
5✔
258
                                        columnType = "DATETIMEOFFSET"
5✔
259
                                default:
49✔
260
                                        columnType = "DATETIME"
49✔
261
                                }
262
                        case "sq.UUIDField":
8✔
263
                                switch p.dialect {
8✔
264
                                case DialectSQLite, DialectPostgres:
8✔
265
                                        columnType = "UUID"
8✔
266
                                default:
×
267
                                        columnType = "BINARY(16)"
×
268
                                }
UNCOV
269
                        default:
×
UNCOV
270
                                continue
×
271
                        }
272
                        if characterLength != "" {
2,453✔
273
                                column := p.cache.GetOrCreateColumn(table, columnName, columnType)
482✔
274
                                column.CharacterLength = characterLength
482✔
275
                        }
482✔
276
                        p.parseColumnModifiers(table, columnName, columnType, loc, structField.Modifiers)
1,971✔
277
                }
278

279
                // Validate column existence for PRIMARY KEY and UNIQUE constraints.
280
                for _, constraint := range table.Constraints {
1,498✔
281
                        if constraint.Ignore {
933✔
282
                                continue
×
283
                        }
284
                        hasInvalidColumn := false
933✔
285
                        for _, columnName := range constraint.Columns {
1,872✔
286
                                column := p.cache.GetColumn(table, columnName)
939✔
287
                                if column == nil {
939✔
288
                                        hasInvalidColumn = true
×
289
                                        loc := p.locations[[2]string{table.TableSchema, constraint.ConstraintName}]
×
290
                                        p.report(loc, strings.Join(constraint.Columns, ",")+": "+columnName+" does not exist in the table")
×
291
                                }
×
292
                        }
293
                        if hasInvalidColumn {
933✔
294
                                continue
×
295
                        }
296
                        // set IsUnique and IsPrimaryKey for the corresponding columns
297
                        if len(constraint.Columns) != 1 || (constraint.ConstraintType != PRIMARY_KEY && constraint.ConstraintType != UNIQUE && constraint.ConstraintType != FOREIGN_KEY) {
938✔
298
                                continue
5✔
299
                        }
300
                        columnName := constraint.Columns[0]
928✔
301
                        column := p.cache.GetColumn(table, columnName)
928✔
302
                        switch constraint.ConstraintType {
928✔
303
                        case PRIMARY_KEY:
379✔
304
                                column.IsPrimaryKey = true
379✔
305
                                if catalog.Dialect == DialectSQLite && strings.EqualFold(column.ColumnType, "INT") {
447✔
306
                                        if _, ok := p.columnExplicitType[[3]string{table.TableSchema, table.TableName, columnName}]; !ok {
136✔
307
                                                column.ColumnType = "INTEGER"
68✔
308
                                        }
68✔
309
                                }
310
                        case UNIQUE:
52✔
311
                                column.IsUnique = true
52✔
312
                        case FOREIGN_KEY:
497✔
313
                                column.ReferencesSchema = constraint.ReferencesSchema
497✔
314
                                column.ReferencesTable = constraint.ReferencesTable
497✔
315
                                column.ReferencesColumn = constraint.ReferencesColumns[0]
497✔
316
                                column.UpdateRule = constraint.UpdateRule
497✔
317
                                column.DeleteRule = constraint.DeleteRule
497✔
318
                                column.IsDeferrable = constraint.IsDeferrable
497✔
319
                                column.IsInitiallyDeferred = constraint.IsInitiallyDeferred
497✔
320
                        }
321
                }
322

323
                // Validate column existence for indexes.
324
                for _, index := range table.Indexes {
774✔
325
                        if index.Ignore {
209✔
326
                                continue
×
327
                        }
328
                        for _, columnName := range index.Columns {
443✔
329
                                column := p.cache.GetColumn(table, columnName)
234✔
330
                                if column == nil {
234✔
331
                                        loc := p.locations[[2]string{table.TableSchema, index.IndexName}]
×
332
                                        p.report(loc, strings.Join(index.Columns, ",")+": "+columnName+" does not exist in the table")
×
333
                                }
×
334
                        }
335
                }
336

337
                // Set PRIMARY KEY columns to NOT NULL.
338
                pkey := p.cache.GetPrimaryKey(table)
565✔
339
                if pkey != nil && !pkey.Ignore {
947✔
340
                        for _, columnName := range pkey.Columns {
767✔
341
                                column := p.cache.GetColumn(table, columnName)
385✔
342
                                if column == nil {
385✔
343
                                        continue
×
344
                                }
345
                                if catalog.Dialect == DialectSQLite && column.ColumnType == "INTEGER" && column.IsPrimaryKey {
463✔
346
                                        // SQLite forbids INTEGER PRIMARY KEY (alias for ROWID) columns
78✔
347
                                        // from being marked as NOT NULL (since they can never be NULL).
78✔
348
                                        continue
78✔
349
                                }
350
                                column.IsNotNull = true
307✔
351
                        }
352
                }
353
        }
354

355
        // Validate column existence for FOREIGN KEY constraints.
356
        for _, schema := range catalog.Schemas {
477✔
357
                for _, table := range schema.Tables {
809✔
358
                        for _, constraint := range table.Constraints {
1,498✔
359
                                if constraint.Ignore || constraint.ConstraintType != FOREIGN_KEY {
1,368✔
360
                                        continue
435✔
361
                                }
362
                                // We need the location of the failing constraint so that we
363
                                // can inform the user where it is. If we can't find it,
364
                                // continue.
365
                                loc, ok := p.locations[[2]string{table.TableSchema, constraint.ConstraintName}]
498✔
366
                                if !ok {
498✔
367
                                        continue
×
368
                                }
369
                                schemaName := constraint.ReferencesSchema
498✔
370
                                if schemaName == "" {
924✔
371
                                        schemaName = catalog.CurrentSchema
426✔
372
                                }
426✔
373
                                refschema := p.cache.GetSchema(catalog, schemaName)
498✔
374
                                if refschema == nil {
498✔
375
                                        p.report(loc, fmt.Sprintf("schema %s does not exist", schemaName))
×
376
                                        continue
×
377
                                }
378
                                reftable := p.cache.GetTable(refschema, constraint.ReferencesTable)
498✔
379
                                tableName := constraint.ReferencesTable
498✔
380
                                if reftable == nil {
498✔
381
                                        if schemaName != "" {
×
382
                                                tableName = schemaName + "." + tableName
×
383
                                        }
×
384
                                        p.report(loc, fmt.Sprintf("table %s does not exist", tableName))
×
385
                                        continue
×
386
                                }
387
                                for _, columnName := range constraint.ReferencesColumns {
997✔
388
                                        refcolumn := p.cache.GetColumn(reftable, columnName)
499✔
389
                                        if refcolumn == nil {
499✔
390
                                                columnName = tableName + "." + columnName
×
391
                                                p.report(loc, fmt.Sprintf("column %s does not exist", columnName))
×
392
                                                continue
×
393
                                        }
394
                                }
395
                        }
396
                }
397
        }
398
        return p.Error()
233✔
399
}
400

401
func (p *StructParser) parseIndexModifier(table *Table, columnNames []string, loc location, m *Modifier) {
209✔
402
        err := m.ParseRawValue()
209✔
403
        if err != nil {
209✔
404
                p.report(loc, err.Error())
×
405
                return
×
406
        }
×
407
        if m.Value != "" && m.Value != "." {
232✔
408
                columnNames = strings.Split(m.Value, ",")
23✔
409
        }
23✔
410
        if len(columnNames) == 0 {
209✔
411
                p.report(loc, "no column provided")
×
412
        }
×
413
        indexName := GenerateName(INDEX, table.TableName, columnNames)
209✔
414
        p.locations[[2]string{table.TableSchema, indexName}] = loc
209✔
415
        index := p.cache.GetOrCreateIndex(table, indexName, columnNames)
209✔
416
        index.TableSchema = table.TableSchema
209✔
417
        index.TableName = table.TableName
209✔
418
        index.Ignore = m.ExcludesDialect(p.dialect)
209✔
419
        for i := range m.Submodifiers {
231✔
420
                submodifier := &m.Submodifiers[i]
22✔
421
                if submodifier.ExcludesDialect(p.dialect) {
22✔
422
                        continue
×
423
                }
424
                switch submodifier.Name {
22✔
425
                case "unique":
2✔
426
                        index.IsUnique = true
2✔
427
                case "using":
20✔
428
                        if p.dialect != DialectPostgres && p.dialect != DialectMySQL {
30✔
429
                                continue
10✔
430
                        }
431
                        index.IndexType = submodifier.RawValue
10✔
432
                default:
×
433
                        p.report(loc, "unknown modifier "+strconv.Quote(submodifier.Name))
×
434
                }
435
        }
436
}
437

438
func (p *StructParser) parsePrimaryKeyUniqueModifier(table *Table, columnNames []string, loc location, m *Modifier) {
435✔
439
        err := m.ParseRawValue()
435✔
440
        if err != nil {
435✔
441
                p.report(loc, err.Error())
×
442
                return
×
443
        }
×
444
        constraintType := PRIMARY_KEY
435✔
445
        if strings.EqualFold(UNIQUE, m.Name) {
488✔
446
                constraintType = UNIQUE
53✔
447
        }
53✔
448
        if m.Value != "" && m.Value != "." {
440✔
449
                columnNames = strings.Split(m.Value, ",")
5✔
450
        }
5✔
451
        if len(columnNames) == 0 {
435✔
452
                p.report(loc, "no column provided")
×
453
        }
×
454
        constraintName := GenerateName(constraintType, table.TableName, columnNames)
435✔
455
        if p.dialect == DialectMySQL && constraintType == PRIMARY_KEY {
530✔
456
                constraintName = "PRIMARY"
95✔
457
        }
95✔
458
        p.locations[[2]string{table.TableSchema, constraintName}] = loc
435✔
459
        constraint := p.cache.GetOrCreateConstraint(table, constraintName, constraintType, columnNames)
435✔
460
        constraint.TableSchema = table.TableSchema
435✔
461
        constraint.TableName = table.TableName
435✔
462
        constraint.Ignore = m.ExcludesDialect(p.dialect)
435✔
463
        for i := range m.Submodifiers {
435✔
464
                submodifier := &m.Submodifiers[i]
×
465
                if submodifier.ExcludesDialect(p.dialect) {
×
466
                        continue
×
467
                }
468
                switch submodifier.Name {
×
469
                case "deferrable":
×
470
                        if p.dialect != DialectSQLite && p.dialect != DialectPostgres {
×
471
                                continue
×
472
                        }
473
                        constraint.IsDeferrable = true
×
474
                case "deferred":
×
475
                        if p.dialect != DialectSQLite && p.dialect != DialectPostgres {
×
476
                                continue
×
477
                        }
478
                        constraint.IsDeferrable = true
×
479
                        constraint.IsInitiallyDeferred = true
×
480
                default:
×
481
                        p.report(loc, "unknown modifier "+strconv.Quote(submodifier.Name))
×
482
                }
483
        }
484
}
485

486
func (p *StructParser) parseForeignKeyModifier(table *Table, loc location, m *Modifier) {
1✔
487
        err := m.ParseRawValue()
1✔
488
        if err != nil {
1✔
489
                p.report(loc, err.Error())
×
490
                return
×
491
        }
×
492
        if m.Value == "" || m.Value == "." {
1✔
493
                p.report(loc, "no column(s) provided")
×
494
                return
×
495
        }
×
496
        if len(m.Submodifiers) == 0 {
1✔
497
                p.report(loc, "no referenced column(s) provided")
×
498
                return
×
499
        }
×
500
        columnNames := strings.Split(m.Value, ",")
1✔
501
        constraintName := GenerateName(FOREIGN_KEY, table.TableName, columnNames)
1✔
502
        p.locations[[2]string{table.TableSchema, constraintName}] = loc
1✔
503
        constraint := p.cache.GetOrCreateConstraint(table, constraintName, FOREIGN_KEY, columnNames)
1✔
504
        constraint.TableSchema = table.TableSchema
1✔
505
        constraint.TableName = table.TableName
1✔
506
        constraint.Ignore = m.ExcludesDialect(p.dialect)
1✔
507
        for i := range m.Submodifiers {
3✔
508
                submodifier := &m.Submodifiers[i]
2✔
509
                if submodifier.ExcludesDialect(p.dialect) {
2✔
510
                        continue
×
511
                }
512
                switch submodifier.Name {
2✔
513
                case "references":
1✔
514
                        switch parts := strings.SplitN(submodifier.RawValue, ".", 3); len(parts) {
1✔
515
                        case 1:
1✔
516
                                constraint.ReferencesTable = parts[0]
1✔
517
                                constraint.ReferencesColumns = columnNames
1✔
518
                        case 2:
×
519
                                constraint.ReferencesTable = parts[0]
×
520
                                constraint.ReferencesColumns = strings.Split(parts[1], ",")
×
521
                        case 3:
×
522
                                constraint.ReferencesSchema = parts[0]
×
523
                                constraint.ReferencesTable = parts[1]
×
524
                                constraint.ReferencesColumns = strings.Split(parts[2], ",")
×
525
                        }
526
                case "onupdate":
×
527
                        switch submodifier.RawValue {
×
528
                        case "cascade":
×
529
                                constraint.UpdateRule = CASCADE
×
530
                        case "restrict":
×
531
                                if p.dialect == DialectSQLServer {
×
532
                                        constraint.UpdateRule = NO_ACTION
×
533
                                } else {
×
534
                                        constraint.UpdateRule = RESTRICT
×
535
                                }
×
536
                        case "noaction":
×
537
                                constraint.UpdateRule = NO_ACTION
×
538
                        case "setnull":
×
539
                                constraint.UpdateRule = SET_NULL
×
540
                        case "setdefault":
×
541
                                constraint.UpdateRule = SET_DEFAULT
×
542
                        case "":
×
543
                                constraint.UpdateRule = ""
×
544
                        default:
×
545
                                loc.keys = append(loc.keys, submodifier.Name)
×
546
                                p.report(loc, "unknown value "+strconv.Quote(submodifier.RawValue))
×
547
                        }
548
                case "ondelete":
×
549
                        switch submodifier.RawValue {
×
550
                        case "cascade":
×
551
                                constraint.DeleteRule = CASCADE
×
552
                        case "restrict":
×
553
                                if p.dialect == DialectSQLServer {
×
554
                                        constraint.DeleteRule = NO_ACTION
×
555
                                } else {
×
556
                                        constraint.DeleteRule = RESTRICT
×
557
                                }
×
558
                        case "noaction":
×
559
                                constraint.DeleteRule = NO_ACTION
×
560
                        case "setnull":
×
561
                                constraint.DeleteRule = SET_NULL
×
562
                        case "setdefault":
×
563
                                constraint.DeleteRule = SET_DEFAULT
×
564
                        case "":
×
565
                                constraint.DeleteRule = ""
×
566
                        default:
×
567
                                loc.keys = append(loc.keys, submodifier.Name)
×
568
                                p.report(loc, "unknown value "+strconv.Quote(submodifier.RawValue))
×
569
                        }
570
                case "deferrable":
×
571
                        if p.dialect != DialectSQLite && p.dialect != DialectPostgres {
×
572
                                continue
×
573
                        }
574
                        constraint.IsDeferrable = true
×
575
                case "deferred":
×
576
                        if p.dialect != DialectSQLite && p.dialect != DialectPostgres {
×
577
                                continue
×
578
                        }
579
                        constraint.IsDeferrable = true
×
580
                        constraint.IsInitiallyDeferred = true
×
581
                case "index":
1✔
582
                        loc.keys = append(loc.keys, submodifier.Name)
1✔
583
                        p.parseIndexModifier(table, columnNames, loc, submodifier)
1✔
584
                default:
×
585
                        p.report(loc, "unknown modifier "+strconv.Quote(submodifier.Name))
×
586
                }
587
        }
588
}
589

590
func (p *StructParser) parseReferencesModifier(table *Table, columnName string, loc location, m *Modifier) {
497✔
591
        err := m.ParseRawValue()
497✔
592
        if err != nil {
497✔
593
                p.report(loc, err.Error())
×
594
                return
×
595
        }
×
596
        if columnName == "" {
497✔
597
                p.report(loc, "no column provided")
×
598
        }
×
599
        constraintName := GenerateName(FOREIGN_KEY, table.TableName, []string{columnName})
497✔
600
        p.locations[[2]string{table.TableSchema, constraintName}] = loc
497✔
601

497✔
602
        constraint := p.cache.GetOrCreateConstraint(table, constraintName, FOREIGN_KEY, []string{columnName})
497✔
603
        constraint.TableSchema = table.TableSchema
497✔
604
        constraint.TableName = table.TableName
497✔
605
        constraint.Ignore = m.ExcludesDialect(p.dialect)
497✔
606
        switch parts := strings.SplitN(m.Value, ".", 3); len(parts) {
497✔
607
        case 1:
82✔
608
                constraint.ReferencesTable = parts[0]
82✔
609
                constraint.ReferencesColumns = []string{columnName}
82✔
610
        case 2:
343✔
611
                constraint.ReferencesTable = parts[0]
343✔
612
                constraint.ReferencesColumns = strings.Split(parts[1], ",")
343✔
613
        case 3:
72✔
614
                constraint.ReferencesSchema = parts[0]
72✔
615
                constraint.ReferencesTable = parts[1]
72✔
616
                constraint.ReferencesColumns = strings.Split(parts[2], ",")
72✔
617
        }
618

619
        for i, submodifier := range m.Submodifiers {
607✔
620
                if submodifier.ExcludesDialect(p.dialect) {
110✔
621
                        continue
×
622
                }
623
                switch submodifier.Name {
110✔
624
                case "onupdate":
37✔
625
                        switch submodifier.RawValue {
37✔
626
                        case "cascade":
37✔
627
                                constraint.UpdateRule = CASCADE
37✔
628
                        case "restrict":
×
629
                                if p.dialect == DialectSQLServer {
×
630
                                        constraint.UpdateRule = NO_ACTION
×
631
                                } else {
×
632
                                        constraint.UpdateRule = RESTRICT
×
633
                                }
×
634
                        case "noaction":
×
635
                                constraint.UpdateRule = NO_ACTION
×
636
                        case "setnull":
×
637
                                constraint.UpdateRule = SET_NULL
×
638
                        case "setdefault":
×
639
                                constraint.UpdateRule = SET_DEFAULT
×
640
                        case "":
×
641
                                constraint.UpdateRule = ""
×
642
                        default:
×
643
                                loc.keys = append(loc.keys, submodifier.Name)
×
644
                                p.report(loc, "unknown value "+strconv.Quote(submodifier.RawValue))
×
645
                        }
646
                case "ondelete":
22✔
647
                        switch submodifier.RawValue {
22✔
648
                        case "cascade":
×
649
                                constraint.DeleteRule = CASCADE
×
650
                        case "restrict":
21✔
651
                                if p.dialect == DialectSQLServer {
21✔
652
                                        constraint.DeleteRule = NO_ACTION
×
653
                                } else {
21✔
654
                                        constraint.DeleteRule = RESTRICT
21✔
655
                                }
21✔
656
                        case "noaction":
×
657
                                constraint.DeleteRule = NO_ACTION
×
658
                        case "setnull":
1✔
659
                                constraint.DeleteRule = SET_NULL
1✔
660
                        case "setdefault":
×
661
                                constraint.DeleteRule = SET_DEFAULT
×
662
                        case "":
×
663
                                constraint.DeleteRule = ""
×
664
                        default:
×
665
                                loc.keys = append(loc.keys, submodifier.Name)
×
666
                                p.report(loc, "unknown value "+strconv.Quote(submodifier.RawValue))
×
667
                        }
668
                case "deferrable":
×
669
                        if p.dialect == DialectSQLite || p.dialect == DialectPostgres {
×
670
                                constraint.IsDeferrable = true
×
671
                        }
×
672
                case "deferred":
15✔
673
                        if p.dialect == DialectSQLite || p.dialect == DialectPostgres {
20✔
674
                                constraint.IsDeferrable = true
5✔
675
                                constraint.IsInitiallyDeferred = true
5✔
676
                        }
5✔
677
                case "index":
36✔
678
                        loc.keys = append(loc.keys, submodifier.Name)
36✔
679
                        p.parseIndexModifier(table, []string{columnName}, loc, &m.Submodifiers[i])
36✔
680
                default:
×
681
                        p.report(loc, "unknown modifier "+strconv.Quote(submodifier.Name))
×
682
                }
683
        }
684
}
685

686
func (p *StructParser) parseColumnModifiers(table *Table, columnName, columnType string, loc location, modifiers []Modifier) {
1,971✔
687
        column := p.cache.GetOrCreateColumn(table, columnName, columnType)
1,971✔
688
        column.TableSchema = table.TableSchema
1,971✔
689
        column.TableName = table.TableName
1,971✔
690
        column.IsEnum = columnType == "sq.EnumField"
1,971✔
691

1,971✔
692
        var dialects []string
1,971✔
693
        for i := range modifiers {
4,276✔
694
                modifier := &modifiers[i]
2,305✔
695
                if len(modifier.Dialects) == 0 {
4,510✔
696
                        modifier.Dialects = dialects
2,205✔
697
                }
2,205✔
698
                if modifier.ExcludesDialect(p.dialect) {
2,460✔
699
                        continue
155✔
700
                }
701
                switch modifier.Name {
2,150✔
702
                case "type":
322✔
703
                        p.columnExplicitType[[3]string{table.TableSchema, table.TableName, columnName}] = struct{}{}
322✔
704
                        column.ColumnType = modifier.RawValue
322✔
705
                        normalizedType, arg1, arg2 := normalizeColumnType(p.dialect, column.ColumnType)
322✔
706
                        switch normalizedType {
322✔
707
                        case "VARBINARY", "BINARY", "NVARCHAR", "VARCHAR", "CHAR":
130✔
708
                                if arg1 != "" {
260✔
709
                                        column.CharacterLength = arg1
130✔
710
                                } else if column.CharacterLength != "" {
130✔
711
                                        column.ColumnType = column.ColumnType + "(" + column.CharacterLength + ")"
×
712
                                }
×
713
                                column.NumericPrecision, column.NumericScale = "", ""
130✔
714
                        case "NUMERIC":
90✔
715
                                if arg1 != "" {
180✔
716
                                        column.NumericPrecision = arg1
90✔
717
                                }
90✔
718
                                if arg2 != "" {
180✔
719
                                        column.NumericScale = arg2
90✔
720
                                }
90✔
721
                                column.CharacterLength = ""
90✔
722
                        default:
102✔
723
                                column.CharacterLength, column.NumericPrecision, column.NumericScale = "", "", ""
102✔
724
                        }
725
                case "len":
60✔
726
                        column.CharacterLength = modifier.RawValue
60✔
727
                        if column.CharacterLength != "" {
120✔
728
                                switch p.dialect {
60✔
729
                                case DialectPostgres, DialectMySQL:
30✔
730
                                        column.ColumnType = "VARCHAR(" + column.CharacterLength + ")"
30✔
731
                                case DialectSQLServer:
15✔
732
                                        column.ColumnType = "NVARCHAR(" + column.CharacterLength + ")"
15✔
733
                                }
734
                        }
735
                case "auto_increment":
20✔
736
                        if p.dialect != DialectMySQL {
35✔
737
                                continue
15✔
738
                        }
739
                        column.IsAutoincrement = true
5✔
740
                case "autoincrement":
10✔
741
                        if p.dialect != DialectSQLite {
10✔
742
                                continue
×
743
                        }
744
                        column.ColumnType = "INTEGER"
10✔
745
                        column.IsAutoincrement = true
10✔
746
                case "identity":
192✔
747
                        switch p.dialect {
192✔
748
                        case DialectPostgres:
62✔
749
                                column.ColumnIdentity = DEFAULT_IDENTITY
62✔
750
                        case DialectSQLServer:
55✔
751
                                column.ColumnIdentity = IDENTITY
55✔
752
                        }
753
                case "alwaysidentity":
×
754
                        switch p.dialect {
×
755
                        case DialectPostgres:
×
756
                                column.ColumnIdentity = ALWAYS_IDENTITY
×
757
                        case DialectSQLServer:
×
758
                                column.ColumnIdentity = IDENTITY
×
759
                        }
760
                case "notnull":
259✔
761
                        column.IsNotNull = true
259✔
762
                case "onupdatecurrenttimestamp":
20✔
763
                        if p.dialect != DialectMySQL {
35✔
764
                                continue
15✔
765
                        }
766
                        column.OnUpdateCurrentTimestamp = true
5✔
767
                case "collate":
30✔
768
                        column.CollationName = modifier.RawValue
30✔
769
                case "default":
99✔
770
                        if p.dialect != DialectPostgres && !isLiteral(modifier.RawValue) {
142✔
771
                                column.ColumnDefault = wrapBrackets(modifier.RawValue)
43✔
772
                                continue
43✔
773
                        }
774
                        column.ColumnDefault = modifier.RawValue
56✔
775
                        if p.dialect == DialectSQLServer {
71✔
776
                                if strings.EqualFold(column.ColumnDefault, "TRUE") {
15✔
777
                                        column.ColumnDefault = "1"
×
778
                                } else if strings.EqualFold(column.ColumnDefault, "FALSE") {
15✔
779
                                        column.ColumnDefault = "0"
×
780
                                }
×
781
                        }
782
                case "generated":
2✔
783
                        column.IsGenerated = true
2✔
784
                case "dialect":
60✔
785
                        if modifier.RawValue == "" {
60✔
786
                                loc.keys = []string{modifier.Name}
×
787
                                p.report(loc, "dialect value cannot be blank")
×
788
                                continue
×
789
                        }
790
                        column.Ignore = true
60✔
791
                        dialects = strings.Split(modifier.RawValue, ",")
60✔
792
                        for _, dialect := range dialects {
120✔
793
                                if p.dialect == dialect {
60✔
794
                                        column.Ignore = false
×
795
                                        break
×
796
                                }
797
                        }
798
                case "index":
149✔
799
                        loc.keys = []string{modifier.Name}
149✔
800
                        p.parseIndexModifier(table, []string{columnName}, loc, modifier)
149✔
801
                case "primarykey", "unique":
430✔
802
                        loc.keys = []string{modifier.Name}
430✔
803
                        p.parsePrimaryKeyUniqueModifier(table, []string{columnName}, loc, modifier)
430✔
804
                case "foreignkey":
×
805
                        loc.keys = []string{modifier.Name}
×
806
                        p.parseForeignKeyModifier(table, loc, modifier)
×
807
                case "references":
497✔
808
                        loc.keys = []string{modifier.Name}
497✔
809
                        p.parseReferencesModifier(table, columnName, loc, modifier)
497✔
810
                default:
×
811
                        p.report(loc, "unknown modifier "+strconv.Quote(modifier.Name))
×
812
                }
813
        }
814
}
815

816
func (p *StructParser) parseTableModifiers(table *Table, loc location, modifiers []Modifier) {
590✔
817
        var dialects []string
590✔
818
        for i := range modifiers {
660✔
819
                modifier := &modifiers[i]
70✔
820
                if len(modifier.Dialects) == 0 {
140✔
821
                        modifier.Dialects = dialects
70✔
822
                }
70✔
823
                switch modifier.Name {
70✔
824
                case "dialect":
40✔
825
                        if modifier.RawValue == "" {
40✔
826
                                loc.keys = []string{modifier.Name}
×
827
                                p.report(loc, "dialect value cannot be blank")
×
828
                                continue
×
829
                        }
830
                        table.Ignore = true
40✔
831
                        dialects = strings.Split(modifier.RawValue, ",")
40✔
832
                        for _, dialect := range dialects {
80✔
833
                                if p.dialect == dialect {
40✔
834
                                        table.Ignore = false
×
835
                                        break
×
836
                                }
837
                        }
838
                case "index":
23✔
839
                        loc.keys = []string{modifier.Name}
23✔
840
                        p.parseIndexModifier(table, nil, loc, modifier)
23✔
841
                case "primarykey", "unique":
5✔
842
                        loc.keys = []string{modifier.Name}
5✔
843
                        p.parsePrimaryKeyUniqueModifier(table, nil, loc, modifier)
5✔
844
                case "foreignkey":
1✔
845
                        loc.keys = []string{modifier.Name}
1✔
846
                        p.parseForeignKeyModifier(table, loc, modifier)
1✔
847
                case "virtual":
1✔
848
                        if p.dialect != DialectSQLite {
1✔
849
                                continue
×
850
                        }
851
                        table.IsVirtual = true
1✔
852
                default:
×
853
                        p.report(loc, "unknown modifier "+strconv.Quote(modifier.Name))
×
854
                }
855
        }
856
}
857

858
func (p *StructParser) report(loc location, msg string) {
×
859
        p.parserDiagnostics.locs = append(p.parserDiagnostics.locs, loc)
×
860
        p.parserDiagnostics.msgs = append(p.parserDiagnostics.msgs, msg)
×
861
}
×
862

863
// location describes where a diagnostic originated: a source position plus the
// struct name, field name and modifier key path used to prefix the message
// (rendered as "structName.fieldName: key1.key2: message").
type location struct {
	// pos is the source position reported alongside the diagnostic.
	pos token.Pos

	// structName and fieldName identify the struct field being parsed.
	structName, fieldName string

	// keys is the modifier key path, e.g. ["references", "index"].
	keys []string
}
869

870
type parserDiagnostics struct {
871
        fset *token.FileSet
872
        locs []location
873
        msgs []string
874
}
875

876
// Diagnostics returns the errors encountered after calling WriteCatalog in a
877
// structured format.
878
func (p *StructParser) Diagnostics() ([]token.Pos, []string) {
×
879
        positions := make([]token.Pos, 0, len(p.parserDiagnostics.msgs))
×
880
        msgs := make([]string, 0, len(p.parserDiagnostics.msgs))
×
881
        for i, msg := range p.parserDiagnostics.msgs {
×
882
                var loc location
×
883
                if i < len(p.parserDiagnostics.locs) {
×
884
                        loc = p.parserDiagnostics.locs[i]
×
885
                }
×
886
                n := len(loc.structName) + len(".") + len(loc.fieldName) + len(": ")
×
887
                for _, key := range loc.keys {
×
888
                        n += len(".") + len(key)
×
889
                }
×
890
                n += len(": ") + len(msg)
×
891
                var b strings.Builder
×
892
                b.Grow(n)
×
893
                if loc.structName != "" && loc.fieldName != "" {
×
894
                        b.WriteString(loc.structName + "." + loc.fieldName + ": ")
×
895
                }
×
896
                for j, key := range loc.keys {
×
897
                        if j > 0 {
×
898
                                b.WriteString(".")
×
899
                        }
×
900
                        b.WriteString(key)
×
901
                }
902
                b.WriteString(": " + msg)
×
903
                positions = append(positions, loc.pos)
×
904
                msgs = append(msgs, b.String())
×
905
        }
906
        return positions, msgs
×
907
}
908

909
// Error returns the errors encountered after calling WriteCatalog.
910
func (p *StructParser) Error() error {
466✔
911
        if len(p.parserDiagnostics.msgs) > 0 {
466✔
912
                return p.parserDiagnostics
×
913
        }
×
914
        return nil
466✔
915
}
916

917
// Error implements the error interface.
918
func (d *parserDiagnostics) Error() string {
×
919
        buf := bufpool.Get().(*bytes.Buffer)
×
920
        buf.Reset()
×
921
        defer bufpool.Put(buf)
×
922
        for i, msg := range d.msgs {
×
923
                if i > 0 {
×
924
                        buf.WriteString("\n")
×
925
                }
×
926
                if i >= len(d.locs) {
×
927
                        buf.WriteString(msg)
×
928
                        continue
×
929
                }
930
                loc := d.locs[i]
×
931
                if d.fset != nil {
×
932
                        pos := d.fset.Position(loc.pos)
×
933
                        if pos.IsValid() {
×
934
                                buf.WriteString(pos.String() + ": ")
×
935
                        }
×
936
                }
937
                if loc.structName != "" && loc.fieldName != "" {
×
938
                        buf.WriteString(loc.structName + "." + loc.fieldName + ": ")
×
939
                }
×
940
                if len(loc.keys) > 0 {
×
941
                        for j, key := range loc.keys {
×
942
                                if j > 0 {
×
943
                                        buf.WriteByte('.')
×
944
                                }
×
945
                                buf.WriteString(key)
×
946
                        }
947
                        buf.WriteString(": ")
×
948
                }
949
                buf.WriteString(msg)
×
950
        }
951
        return buf.String()
×
952
}
953

954
// Analyzer is an &analysis.Analyzer which can be used in a custom linter.
955
var Analyzer = &analysis.Analyzer{
956
        Name:     "ddl",
957
        Doc:      "validates ddl structs",
958
        Requires: []*analysis.Analyzer{inspect.Analyzer},
959
        Run: func(pass *analysis.Pass) (any, error) {
×
960
                inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
×
961
                nodeFilter := []ast.Node{(*ast.TypeSpec)(nil)}
×
962
                p := NewStructParser(pass.Fset)
×
963
                inspect.Preorder(nodeFilter, p.VisitStruct)
×
964
                var catalog Catalog
×
965
                _ = p.WriteCatalog(&catalog)
×
966
                positions, msgs := p.Diagnostics()
×
967
                if len(msgs) == 0 {
×
968
                        return nil, nil
×
969
                }
×
970
                for i, msg := range msgs {
×
971
                        var pos token.Pos
×
972
                        if i < len(positions) {
×
973
                                pos = positions[i]
×
974
                        }
×
975
                        pass.Reportf(pos, msg)
×
976
                }
977
                return nil, nil
×
978
        },
979
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc