• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 16295271531

15 Jul 2025 01:55PM UTC coverage: 56.553% (+0.2%) from 56.314%
16295271531

Pull #1294

github

dsyers
lint
Pull Request #1294: Triggers

790 of 1325 new or added lines in 24 files covered. (59.62%)

4 existing lines in 4 files now uncovered.

5277 of 9331 relevant lines covered (56.55%)

55.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.84
/database/sqlserver/sqlserver.go
1
package sqlserver
2

3
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	nurl "net/url"
	"strconv"
	"strings"

	"go.uber.org/atomic"

	"github.com/Azure/go-autorest/autorest/adal"
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	"github.com/hashicorp/go-multierror"
	mssql "github.com/microsoft/go-mssqldb" // mssql support
)
20

21
func init() {
2✔
22
        database.Register("sqlserver", &SQLServer{})
2✔
23
}
2✔
24

25
// DefaultMigrationsTable is the name of the migrations table in the database.
// WithInstance falls back to it when Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"

// Sentinel errors returned by this driver; compare against them with == or
// errors.Is.
var (
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig                 = errors.New("no config")
	// ErrNoDatabaseName is returned by WithInstance when SELECT DB_NAME()
	// yields an empty name.
	ErrNoDatabaseName            = errors.New("no database name")
	// ErrNoSchema is returned by WithInstance when SELECT SCHEMA_NAME()
	// yields an empty name.
	ErrNoSchema                  = errors.New("no schema")
	// ErrDatabaseDirty reports a dirty database state.
	// NOTE(review): not returned anywhere in this file; presumably kept for
	// API compatibility.
	ErrDatabaseDirty             = errors.New("database is dirty")
	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries a
	// password and also sets useMsi=true; only one authentication method may
	// be selected. Go error strings carry no trailing punctuation.
	ErrMultipleAuthOptionsPassed = errors.New("both password and useMsi=true were passed")
)
35

36
// lockErrorMap turns the negative status codes that Lock can receive from
// sp_getapplock into readable descriptions for its error message. Statuses
// missing from this table are reported by Lock as "Unknown error".
var lockErrorMap = map[int]string{
	-999: "Parameter validation or other call error.",
	-3:   "The lock request was chosen as a deadlock victim.",
	-2:   "The lock request was canceled.",
	-1:   "The lock request timed out.",
}
42

43
// Config holds the settings of a SQL Server driver instance.
//
// WithInstance fills in MigrationsTable, DatabaseName and SchemaName when they
// are empty, and it keeps (and mutates) the pointer it is given rather than a
// copy.
type Config struct {
	// MigrationsTable is the name of the table recording the current version;
	// WithInstance sets it to DefaultMigrationsTable when empty.
	MigrationsTable string
	// DatabaseName, together with SchemaName, seeds the advisory lock ID used by
	// Lock and Unlock; WithInstance queries SELECT DB_NAME() when it is empty.
	DatabaseName    string
	// SchemaName qualifies the migrations table name; WithInstance queries
	// SELECT SCHEMA_NAME() when it is empty.
	SchemaName      string

	// Triggers maps a trigger name (such as database.TrigRunPre) to a callback.
	// Trigger calls the matching callback with a TriggerResponse value; names
	// without an entry are ignored. AddTriggers replaces the whole map.
	Triggers map[string]func(response interface{}) error
}
51

52
// SQLServer is the golang-migrate database driver for Microsoft SQL Server,
// registered under the "sqlserver" scheme.
type SQLServer struct {
	// Locking and unlocking need to use the same connection: Lock takes the
	// application lock with @LockOwner = 'Session', so it belongs to this one
	// session. conn is also used by Run, SetVersion, Version and Drop.
	conn     *sql.Conn
	// db is the pool conn was checked out of; Close closes both.
	db       *sql.DB
	// isLocked records whether this driver currently holds the lock; Lock and
	// Unlock flip it through database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
62

63
// TriggerResponse is the value passed to a callback registered in
// Config.Triggers each time the driver fires that trigger.
type TriggerResponse struct {
	// Driver is the driver instance firing the trigger.
	Driver  *SQLServer
	// Config is that driver's configuration.
	Config  *Config
	// Trigger is the name the callback was registered under.
	Trigger string
	// Detail is a trigger-specific payload. In this file it is
	// struct{ Query string } for the Run triggers,
	// struct{ Version int; Dirty bool } for the SetVersion triggers, and nil
	// for the VersionTable triggers.
	Detail  interface{}
}
69

70
// WithInstance returns a database instance from an already created database connection.
71
//
72
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
73
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
32✔
74
        if config == nil {
32✔
75
                return nil, ErrNilConfig
×
76
        }
×
77

78
        if err := instance.Ping(); err != nil {
40✔
79
                return nil, err
8✔
80
        }
8✔
81

82
        if config.DatabaseName == "" {
48✔
83
                query := `SELECT DB_NAME()`
24✔
84
                var databaseName string
24✔
85
                if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
24✔
86
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
87
                }
×
88

89
                if len(databaseName) == 0 {
24✔
90
                        return nil, ErrNoDatabaseName
×
91
                }
×
92

93
                config.DatabaseName = databaseName
24✔
94
        }
95

96
        if config.SchemaName == "" {
48✔
97
                query := `SELECT SCHEMA_NAME()`
24✔
98
                var schemaName string
24✔
99
                if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
24✔
100
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
101
                }
×
102

103
                if len(schemaName) == 0 {
24✔
104
                        return nil, ErrNoSchema
×
105
                }
×
106

107
                config.SchemaName = schemaName
24✔
108
        }
109

110
        if len(config.MigrationsTable) == 0 {
48✔
111
                config.MigrationsTable = DefaultMigrationsTable
24✔
112
        }
24✔
113

114
        conn, err := instance.Conn(context.Background())
24✔
115

24✔
116
        if err != nil {
24✔
117
                return nil, err
×
118
        }
×
119

120
        ss := &SQLServer{
24✔
121
                conn:   conn,
24✔
122
                db:     instance,
24✔
123
                config: config,
24✔
124
        }
24✔
125

24✔
126
        if err := ss.ensureVersionTable(); err != nil {
24✔
127
                return nil, err
×
128
        }
×
129

130
        return ss, nil
24✔
131
}
132

133
// Open a connection to the database.
134
func (ss *SQLServer) Open(url string) (database.Driver, error) {
36✔
135
        purl, err := nurl.Parse(url)
36✔
136
        if err != nil {
36✔
137
                return nil, err
×
138
        }
×
139

140
        useMsiParam := purl.Query().Get("useMsi")
36✔
141
        useMsi := false
36✔
142
        if len(useMsiParam) > 0 {
52✔
143
                useMsi, err = strconv.ParseBool(useMsiParam)
16✔
144
                if err != nil {
16✔
145
                        return nil, err
×
146
                }
×
147
        }
148

149
        if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
40✔
150
                return nil, ErrMultipleAuthOptionsPassed
4✔
151
        }
4✔
152

153
        filteredURL := migrate.FilterCustomQuery(purl).String()
32✔
154

32✔
155
        var db *sql.DB
32✔
156
        if useMsi {
36✔
157
                resource := getAADResourceFromServerUri(purl)
4✔
158
                tokenProvider, err := getMSITokenProvider(resource)
4✔
159
                if err != nil {
4✔
160
                        return nil, err
×
161
                }
×
162

163
                connector, err := mssql.NewAccessTokenConnector(
4✔
164
                        filteredURL, tokenProvider)
4✔
165
                if err != nil {
4✔
166
                        return nil, err
×
167
                }
×
168

169
                db = sql.OpenDB(connector)
4✔
170

171
        } else {
28✔
172
                db, err = sql.Open("sqlserver", filteredURL)
28✔
173
                if err != nil {
28✔
174
                        return nil, err
×
175
                }
×
176
        }
177

178
        migrationsTable := purl.Query().Get("x-migrations-table")
32✔
179

32✔
180
        px, err := WithInstance(db, &Config{
32✔
181
                DatabaseName:    purl.Path,
32✔
182
                MigrationsTable: migrationsTable,
32✔
183
        })
32✔
184

32✔
185
        if err != nil {
40✔
186
                return nil, err
8✔
187
        }
8✔
188

189
        return px, nil
24✔
190
}
191

192
// Close the database connection
193
func (ss *SQLServer) Close() error {
20✔
194
        connErr := ss.conn.Close()
20✔
195
        dbErr := ss.db.Close()
20✔
196
        if connErr != nil || dbErr != nil {
20✔
197
                return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
×
198
        }
×
199
        return nil
20✔
200
}
201

202
func (ss *SQLServer) AddTriggers(t map[string]func(response interface{}) error) {
4✔
203
        ss.config.Triggers = t
4✔
204
}
4✔
205

206
func (ss *SQLServer) Trigger(name string, detail interface{}) error {
444✔
207
        if ss.config.Triggers == nil {
672✔
208
                return nil
228✔
209
        }
228✔
210

211
        if trigger, ok := ss.config.Triggers[name]; ok {
288✔
212
                return trigger(TriggerResponse{
72✔
213
                        Driver:  ss,
72✔
214
                        Config:  ss.config,
72✔
215
                        Trigger: name,
72✔
216
                        Detail:  detail,
72✔
217
                })
72✔
218
        }
72✔
219

220
        return nil
144✔
221
}
222

223
// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
224
func (ss *SQLServer) Lock() error {
76✔
225
        return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
140✔
226
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
64✔
227
                if err != nil {
64✔
228
                        return err
×
229
                }
×
230

231
                // This will block until the lock is acquired.
232
                // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
233
                query := `
64✔
234
                DECLARE @lockResult int;
64✔
235
                EXEC @lockResult = sp_getapplock @Resource = @p1, @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1;
64✔
236
                SELECT @lockResult;`
64✔
237

64✔
238
                var status int
64✔
239
                if err = ss.conn.QueryRowContext(context.Background(), query, aid).Scan(&status); err == nil && status > -1 {
128✔
240
                        return nil
64✔
241
                } else if err != nil {
64✔
242
                        return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
×
243
                } else {
×
244
                        errorDescription, ok := lockErrorMap[status]
×
245
                        if !ok {
×
246
                                errorDescription = "Unknown error"
×
247
                        }
×
248
                        return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, errorDescription), Query: []byte(query)}
×
249
                }
250
        })
251
}
252

253
// Unlock froms the migration lock from the database
254
func (ss *SQLServer) Unlock() error {
64✔
255
        return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
128✔
256
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
64✔
257
                if err != nil {
64✔
258
                        return err
×
259
                }
×
260

261
                // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
262
                query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
64✔
263
                if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
64✔
264
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
265
                }
×
266

267
                return nil
64✔
268
        })
269
}
270

271
// Run the migrations for the database
272
func (ss *SQLServer) Run(migration io.Reader) error {
56✔
273
        migr, err := io.ReadAll(migration)
56✔
274
        if err != nil {
56✔
275
                return err
×
276
        }
×
277

278
        // run migration
279
        query := string(migr[:])
56✔
280
        if err := ss.Trigger(database.TrigRunPre, struct {
56✔
281
                Query string
56✔
282
        }{Query: query}); err != nil {
56✔
NEW
283
                return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"}
×
NEW
284
        }
×
285
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
60✔
286
                if msErr, ok := err.(mssql.Error); ok {
8✔
287
                        message := fmt.Sprintf("migration failed: %s", msErr.Message)
4✔
288
                        if msErr.ProcName != "" {
4✔
289
                                message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
×
290
                        }
×
291
                        return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
4✔
292
                }
293
                return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
×
294
        }
295
        if err := ss.Trigger(database.TrigRunPost, struct {
52✔
296
                Query string
52✔
297
        }{Query: query}); err != nil {
52✔
NEW
298
                return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"}
×
NEW
299
        }
×
300

301
        return nil
52✔
302
}
303

304
// SetVersion for the current database
305
func (ss *SQLServer) SetVersion(version int, dirty bool) error {
144✔
306

144✔
307
        tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
144✔
308
        if err != nil {
144✔
309
                return &database.Error{OrigErr: err, Err: "transaction start failed"}
×
310
        }
×
311

312
        if err := ss.Trigger(database.TrigSetVersionPre, struct {
144✔
313
                Version int
144✔
314
                Dirty   bool
144✔
315
        }{Version: version, Dirty: dirty}); err != nil {
144✔
NEW
316
                if errRollback := tx.Rollback(); errRollback != nil {
×
NEW
317
                        err = multierror.Append(err, errRollback)
×
NEW
318
                }
×
NEW
319
                return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"}
×
320
        }
321

322
        query := `TRUNCATE TABLE ` + ss.getMigrationTable()
144✔
323
        if _, err := tx.Exec(query); err != nil {
144✔
324
                if errRollback := tx.Rollback(); errRollback != nil {
×
325
                        err = multierror.Append(err, errRollback)
×
326
                }
×
327
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
328
        }
329

330
        // Also re-write the schema version for nil dirty versions to prevent
331
        // empty schema version for failed down migration on the first migration
332
        // See: https://github.com/golang-migrate/migrate/issues/330
333
        if version >= 0 || (version == database.NilVersion && dirty) {
276✔
334
                var dirtyBit int
132✔
335
                if dirty {
204✔
336
                        dirtyBit = 1
72✔
337
                }
72✔
338
                query = `INSERT INTO ` + ss.getMigrationTable() + ` (version, dirty) VALUES (@p1, @p2)`
132✔
339
                if _, err := tx.Exec(query, version, dirtyBit); err != nil {
132✔
340
                        if errRollback := tx.Rollback(); errRollback != nil {
×
341
                                err = multierror.Append(err, errRollback)
×
342
                        }
×
343
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
344
                }
345
        }
346

347
        if err := ss.Trigger(database.TrigSetVersionPost, struct {
144✔
348
                Version int
144✔
349
                Dirty   bool
144✔
350
        }{Version: version, Dirty: dirty}); err != nil {
144✔
NEW
351
                if errRollback := tx.Rollback(); errRollback != nil {
×
NEW
352
                        err = multierror.Append(err, errRollback)
×
NEW
353
                }
×
NEW
354
                return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"}
×
355
        }
356

357
        if err := tx.Commit(); err != nil {
144✔
358
                return &database.Error{OrigErr: err, Err: "transaction commit failed"}
×
359
        }
×
360

361
        return nil
144✔
362
}
363

364
// Version of the current database state
365
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
88✔
366
        query := `SELECT TOP 1 version, dirty FROM ` + ss.getMigrationTable()
88✔
367
        err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
88✔
368
        switch {
88✔
369
        case err == sql.ErrNoRows:
28✔
370
                return database.NilVersion, false, nil
28✔
371

372
        case err != nil:
×
373
                // FIXME: convert to MSSQL error
×
374
                return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
×
375

376
        default:
60✔
377
                return version, dirty, nil
60✔
378
        }
379
}
380

381
// Drop all tables from the database.
382
func (ss *SQLServer) Drop() error {
16✔
383

16✔
384
        // drop all referential integrity constraints
16✔
385
        query := `
16✔
386
        DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
16✔
387

16✔
388
        SET @Cursor = CURSOR FAST_FORWARD FOR
16✔
389
        SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
16✔
390
        FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
16✔
391
        LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
16✔
392

16✔
393
        OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
16✔
394

16✔
395
        WHILE (@@FETCH_STATUS = 0)
16✔
396
        BEGIN
16✔
397
        Exec sp_executesql @Sql
16✔
398
        FETCH NEXT FROM @Cursor INTO @Sql
16✔
399
        END
16✔
400

16✔
401
        CLOSE @Cursor DEALLOCATE @Cursor`
16✔
402

16✔
403
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
16✔
404
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
405
        }
×
406

407
        // drop the tables
408
        query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
16✔
409
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
16✔
410
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
411
        }
×
412

413
        return nil
16✔
414
}
415

416
func (ss *SQLServer) ensureVersionTable() (err error) {
24✔
417
        if err = ss.Lock(); err != nil {
24✔
418
                return err
×
419
        }
×
420

421
        defer func() {
48✔
422
                if e := ss.Unlock(); e != nil {
24✔
423
                        if err == nil {
×
424
                                err = e
×
425
                        } else {
×
426
                                err = multierror.Append(err, e)
×
427
                        }
×
428
                }
429
        }()
430

431
        if err := ss.Trigger(database.TrigVersionTablePre, nil); err != nil {
24✔
NEW
432
                return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"}
×
NEW
433
        }
×
434

435
        query := `IF NOT EXISTS
24✔
436
        (SELECT *
24✔
437
                 FROM sysobjects
24✔
438
                WHERE id = object_id(N'` + ss.getMigrationTable() + `')
24✔
439
                        AND OBJECTPROPERTY(id, N'IsUserTable') = 1
24✔
440
        )
24✔
441
        CREATE TABLE ` + ss.getMigrationTable() + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
24✔
442

24✔
443
        if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
24✔
444
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
445
        }
×
446

447
        if err := ss.Trigger(database.TrigVersionTablePost, nil); err != nil {
24✔
NEW
448
                return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"}
×
NEW
449
        }
×
450

451
        return nil
24✔
452
}
453

454
func (ss *SQLServer) getMigrationTable() string {
412✔
455
        return fmt.Sprintf("[%s].[%s]", ss.config.SchemaName, ss.config.MigrationsTable)
412✔
456
}
412✔
457

458
func getMSITokenProvider(resource string) (func() (string, error), error) {
4✔
459
        msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
4✔
460
        if err != nil {
4✔
461
                return nil, err
×
462
        }
×
463

464
        return func() (string, error) {
8✔
465
                err := msi.EnsureFresh()
4✔
466
                if err != nil {
8✔
467
                        return "", err
4✔
468
                }
4✔
469
                token := msi.OAuthToken()
×
470
                return token, nil
×
471
        }, nil
472
}
473

474
// getAADResourceFromServerUri derives the Azure AD resource for the server by
// dropping the first DNS label of the URL's host. The sql server resource can
// change across clouds, so it is computed from the server URI rather than
// hard-coded.
// ex. <server name>.database.windows.net -> https://database.windows.net
// A host without any "." yields just "https://".
func getAADResourceFromServerUri(purl *nurl.URL) string {
	host := purl.Hostname()
	domain := ""
	if dot := strings.Index(host, "."); dot >= 0 {
		domain = host[dot+1:]
	}
	return "https://" + domain
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc