• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 17662383123

12 Sep 2025 02:06AM UTC coverage: 50.017% (-4.0%) from 54.037%
17662383123

Pull #1318

github

daniel-garcia
store db migrates in the database if it supports it
Pull Request #1318: WIP: Store migrates in db

163 of 653 new or added lines in 7 files covered. (24.96%)

180 existing lines in 3 files now uncovered.

4362 of 8721 relevant lines covered (50.02%)

44.14 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

1.17
/database/sqlserver/sqlserver.go
1
package sqlserver
2

3
import (
4
        "context"
5
        "database/sql"
6
        "fmt"
7
        "io"
8
        nurl "net/url"
9
        "strconv"
10
        "strings"
11
        "sync/atomic"
12

13
        "github.com/Azure/go-autorest/autorest/adal"
14
        "github.com/golang-migrate/migrate/v4"
15
        "github.com/golang-migrate/migrate/v4/database"
16
        "github.com/hashicorp/go-multierror"
17
        mssql "github.com/microsoft/go-mssqldb" // mssql support
18
)
19

20
func init() {
2✔
21
        database.Register("sqlserver", &SQLServer{})
2✔
22
}
2✔
23

24
// DefaultMigrationsTable is the name of the migrations table in the database.
// WithInstance falls back to it when Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"

// Errors returned by the SQL Server driver.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithInstance when SELECT DB_NAME()
	// yields an empty name.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrNoSchema is returned by WithInstance when SELECT SCHEMA_NAME() yields
	// an empty name.
	ErrNoSchema = fmt.Errorf("no schema")

	// ErrDatabaseDirty reports a dirty database. It is not returned anywhere
	// in this file.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")

	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries
	// both a password and useMsi=true.
	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
)

// lockErrorMap translates the negative return codes of sp_getapplock into the
// descriptions that Lock puts into its error messages.
var lockErrorMap = map[int]string{
	-999: "Parameter validation or other call error.",
	-3:   "The lock request was chosen as a deadlock victim.",
	-2:   "The lock request was canceled.",
	-1:   "The lock request timed out.",
}
41

42
// Config for database.
type Config struct {
	// MigrationsTable is the table that stores the current version and dirty
	// flag. WithInstance uses DefaultMigrationsTable when it is empty.
	MigrationsTable string

	// DatabaseName is filled from SELECT DB_NAME() by WithInstance when empty.
	// Together with SchemaName it is used to derive the advisory lock ID.
	DatabaseName string

	// SchemaName is the schema that holds the migrations table. WithInstance
	// fills it from SELECT SCHEMA_NAME() when empty.
	SchemaName string
}

// SQL Server connection; this is the database.Driver returned by WithInstance.
type SQLServer struct {
	// Open and WithInstance need to guarantee that config is never nil.
	config *Config

	// db is the connection pool; conn is a single connection taken from it.
	// Locking and unlocking need to use the same connection, because the
	// application lock is owned by the session.
	db   *sql.DB
	conn *sql.Conn

	// isLocked records whether this driver currently holds the migration lock.
	isLocked atomic.Bool
}
59

60
// WithInstance returns a database instance from an already created database connection.
61
//
62
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
UNCOV
63
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
×
UNCOV
64
        if config == nil {
×
65
                return nil, ErrNilConfig
×
66
        }
×
67

UNCOV
68
        if err := instance.Ping(); err != nil {
×
UNCOV
69
                return nil, err
×
UNCOV
70
        }
×
71

UNCOV
72
        if config.DatabaseName == "" {
×
UNCOV
73
                query := `SELECT DB_NAME()`
×
UNCOV
74
                var databaseName string
×
UNCOV
75
                if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
×
76
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
77
                }
×
78

UNCOV
79
                if len(databaseName) == 0 {
×
80
                        return nil, ErrNoDatabaseName
×
81
                }
×
82

UNCOV
83
                config.DatabaseName = databaseName
×
84
        }
85

UNCOV
86
        if config.SchemaName == "" {
×
UNCOV
87
                query := `SELECT SCHEMA_NAME()`
×
UNCOV
88
                var schemaName string
×
UNCOV
89
                if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
×
90
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
91
                }
×
92

UNCOV
93
                if len(schemaName) == 0 {
×
94
                        return nil, ErrNoSchema
×
95
                }
×
96

UNCOV
97
                config.SchemaName = schemaName
×
98
        }
99

UNCOV
100
        if len(config.MigrationsTable) == 0 {
×
UNCOV
101
                config.MigrationsTable = DefaultMigrationsTable
×
UNCOV
102
        }
×
103

UNCOV
104
        conn, err := instance.Conn(context.Background())
×
UNCOV
105

×
UNCOV
106
        if err != nil {
×
107
                return nil, err
×
108
        }
×
109

UNCOV
110
        ss := &SQLServer{
×
UNCOV
111
                conn:   conn,
×
UNCOV
112
                db:     instance,
×
UNCOV
113
                config: config,
×
UNCOV
114
        }
×
UNCOV
115

×
UNCOV
116
        if err := ss.ensureVersionTable(); err != nil {
×
117
                return nil, err
×
118
        }
×
119

UNCOV
120
        return ss, nil
×
121
}
122

123
// Open a connection to the database.
UNCOV
124
func (ss *SQLServer) Open(url string) (database.Driver, error) {
×
UNCOV
125
        purl, err := nurl.Parse(url)
×
UNCOV
126
        if err != nil {
×
127
                return nil, err
×
128
        }
×
129

UNCOV
130
        useMsiParam := purl.Query().Get("useMsi")
×
UNCOV
131
        useMsi := false
×
UNCOV
132
        if len(useMsiParam) > 0 {
×
UNCOV
133
                useMsi, err = strconv.ParseBool(useMsiParam)
×
UNCOV
134
                if err != nil {
×
135
                        return nil, err
×
136
                }
×
137
        }
138

UNCOV
139
        if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
×
UNCOV
140
                return nil, ErrMultipleAuthOptionsPassed
×
UNCOV
141
        }
×
142

UNCOV
143
        filteredURL := migrate.FilterCustomQuery(purl).String()
×
UNCOV
144

×
UNCOV
145
        var db *sql.DB
×
UNCOV
146
        if useMsi {
×
UNCOV
147
                resource := getAADResourceFromServerUri(purl)
×
UNCOV
148
                tokenProvider, err := getMSITokenProvider(resource)
×
UNCOV
149
                if err != nil {
×
150
                        return nil, err
×
151
                }
×
152

UNCOV
153
                connector, err := mssql.NewAccessTokenConnector(
×
UNCOV
154
                        filteredURL, tokenProvider)
×
UNCOV
155
                if err != nil {
×
156
                        return nil, err
×
157
                }
×
158

UNCOV
159
                db = sql.OpenDB(connector)
×
160

UNCOV
161
        } else {
×
UNCOV
162
                db, err = sql.Open("sqlserver", filteredURL)
×
UNCOV
163
                if err != nil {
×
164
                        return nil, err
×
165
                }
×
166
        }
167

UNCOV
168
        migrationsTable := purl.Query().Get("x-migrations-table")
×
UNCOV
169

×
UNCOV
170
        px, err := WithInstance(db, &Config{
×
UNCOV
171
                DatabaseName:    purl.Path,
×
UNCOV
172
                MigrationsTable: migrationsTable,
×
UNCOV
173
        })
×
UNCOV
174

×
UNCOV
175
        if err != nil {
×
UNCOV
176
                return nil, err
×
UNCOV
177
        }
×
178

UNCOV
179
        return px, nil
×
180
}
181

182
// Close the database connection
UNCOV
183
func (ss *SQLServer) Close() error {
×
UNCOV
184
        connErr := ss.conn.Close()
×
UNCOV
185
        dbErr := ss.db.Close()
×
UNCOV
186
        if connErr != nil || dbErr != nil {
×
187
                return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
×
188
        }
×
UNCOV
189
        return nil
×
190
}
191

192
// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
UNCOV
193
func (ss *SQLServer) Lock() error {
×
UNCOV
194
        return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
×
UNCOV
195
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
×
UNCOV
196
                if err != nil {
×
197
                        return err
×
198
                }
×
199

200
                // This will block until the lock is acquired.
201
                // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
UNCOV
202
                query := `
×
UNCOV
203
                DECLARE @lockResult int;
×
UNCOV
204
                EXEC @lockResult = sp_getapplock @Resource = @p1, @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1;
×
UNCOV
205
                SELECT @lockResult;`
×
UNCOV
206

×
UNCOV
207
                var status int
×
UNCOV
208
                if err = ss.conn.QueryRowContext(context.Background(), query, aid).Scan(&status); err == nil && status > -1 {
×
UNCOV
209
                        return nil
×
UNCOV
210
                } else if err != nil {
×
211
                        return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
×
212
                } else {
×
213
                        errorDescription, ok := lockErrorMap[status]
×
214
                        if !ok {
×
215
                                errorDescription = "Unknown error"
×
216
                        }
×
217
                        return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, errorDescription), Query: []byte(query)}
×
218
                }
219
        })
220
}
221

222
// Unlock froms the migration lock from the database
UNCOV
223
func (ss *SQLServer) Unlock() error {
×
UNCOV
224
        return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
×
UNCOV
225
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
×
UNCOV
226
                if err != nil {
×
227
                        return err
×
228
                }
×
229

230
                // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
UNCOV
231
                query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
×
UNCOV
232
                if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
×
233
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
234
                }
×
235

UNCOV
236
                return nil
×
237
        })
238
}
239

240
// Run the migrations for the database
UNCOV
241
func (ss *SQLServer) Run(migration io.Reader) error {
×
UNCOV
242
        migr, err := io.ReadAll(migration)
×
UNCOV
243
        if err != nil {
×
244
                return err
×
245
        }
×
246

247
        // run migration
UNCOV
248
        query := string(migr[:])
×
UNCOV
249
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
×
UNCOV
250
                if msErr, ok := err.(mssql.Error); ok {
×
UNCOV
251
                        message := fmt.Sprintf("migration failed: %s", msErr.Message)
×
UNCOV
252
                        if msErr.ProcName != "" {
×
253
                                message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
×
254
                        }
×
UNCOV
255
                        return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
×
256
                }
257
                return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
×
258
        }
259

UNCOV
260
        return nil
×
261
}
262

263
// SetVersion for the current database
UNCOV
264
func (ss *SQLServer) SetVersion(version int, dirty bool) error {
×
UNCOV
265

×
UNCOV
266
        tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
×
UNCOV
267
        if err != nil {
×
268
                return &database.Error{OrigErr: err, Err: "transaction start failed"}
×
269
        }
×
270

UNCOV
271
        query := `TRUNCATE TABLE ` + ss.getMigrationTable()
×
UNCOV
272
        if _, err := tx.Exec(query); err != nil {
×
273
                if errRollback := tx.Rollback(); errRollback != nil {
×
274
                        err = multierror.Append(err, errRollback)
×
275
                }
×
276
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
277
        }
278

279
        // Also re-write the schema version for nil dirty versions to prevent
280
        // empty schema version for failed down migration on the first migration
281
        // See: https://github.com/golang-migrate/migrate/issues/330
UNCOV
282
        if version >= 0 || (version == database.NilVersion && dirty) {
×
UNCOV
283
                var dirtyBit int
×
UNCOV
284
                if dirty {
×
UNCOV
285
                        dirtyBit = 1
×
UNCOV
286
                }
×
UNCOV
287
                query = `INSERT INTO ` + ss.getMigrationTable() + ` (version, dirty) VALUES (@p1, @p2)`
×
UNCOV
288
                if _, err := tx.Exec(query, version, dirtyBit); err != nil {
×
289
                        if errRollback := tx.Rollback(); errRollback != nil {
×
290
                                err = multierror.Append(err, errRollback)
×
291
                        }
×
292
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
293
                }
294
        }
295

UNCOV
296
        if err := tx.Commit(); err != nil {
×
297
                return &database.Error{OrigErr: err, Err: "transaction commit failed"}
×
298
        }
×
299

UNCOV
300
        return nil
×
301
}
302

303
// Version of the current database state
UNCOV
304
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
×
UNCOV
305
        query := `SELECT TOP 1 version, dirty FROM ` + ss.getMigrationTable()
×
UNCOV
306
        err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
×
UNCOV
307
        switch {
×
UNCOV
308
        case err == sql.ErrNoRows:
×
UNCOV
309
                return database.NilVersion, false, nil
×
310

311
        case err != nil:
×
312
                // FIXME: convert to MSSQL error
×
313
                return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
×
314

UNCOV
315
        default:
×
UNCOV
316
                return version, dirty, nil
×
317
        }
318
}
319

320
// Drop all tables from the database.
UNCOV
321
func (ss *SQLServer) Drop() error {
×
UNCOV
322

×
UNCOV
323
        // drop all referential integrity constraints
×
UNCOV
324
        query := `
×
UNCOV
325
        DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
×
UNCOV
326

×
UNCOV
327
        SET @Cursor = CURSOR FAST_FORWARD FOR
×
UNCOV
328
        SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
×
UNCOV
329
        FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
×
UNCOV
330
        LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
×
UNCOV
331

×
UNCOV
332
        OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
×
UNCOV
333

×
UNCOV
334
        WHILE (@@FETCH_STATUS = 0)
×
UNCOV
335
        BEGIN
×
UNCOV
336
        Exec sp_executesql @Sql
×
UNCOV
337
        FETCH NEXT FROM @Cursor INTO @Sql
×
UNCOV
338
        END
×
UNCOV
339

×
UNCOV
340
        CLOSE @Cursor DEALLOCATE @Cursor`
×
UNCOV
341

×
UNCOV
342
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
×
343
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
344
        }
×
345

346
        // drop the tables
UNCOV
347
        query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
×
UNCOV
348
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
×
349
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
350
        }
×
351

UNCOV
352
        return nil
×
353
}
354

UNCOV
355
func (ss *SQLServer) ensureVersionTable() (err error) {
×
UNCOV
356
        if err = ss.Lock(); err != nil {
×
357
                return err
×
358
        }
×
359

UNCOV
360
        defer func() {
×
UNCOV
361
                if e := ss.Unlock(); e != nil {
×
362
                        if err == nil {
×
363
                                err = e
×
364
                        } else {
×
365
                                err = multierror.Append(err, e)
×
366
                        }
×
367
                }
368
        }()
369

UNCOV
370
        query := `IF NOT EXISTS
×
UNCOV
371
        (SELECT *
×
UNCOV
372
                 FROM sysobjects
×
UNCOV
373
                WHERE id = object_id(N'` + ss.getMigrationTable() + `')
×
UNCOV
374
                        AND OBJECTPROPERTY(id, N'IsUserTable') = 1
×
UNCOV
375
        )
×
UNCOV
376
        CREATE TABLE ` + ss.getMigrationTable() + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
×
UNCOV
377

×
UNCOV
378
        if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
×
379
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
380
        }
×
381

UNCOV
382
        return nil
×
383
}
384

UNCOV
385
func (ss *SQLServer) getMigrationTable() string {
×
UNCOV
386
        return fmt.Sprintf("[%s].[%s]", ss.config.SchemaName, ss.config.MigrationsTable)
×
UNCOV
387
}
×
388

UNCOV
389
func getMSITokenProvider(resource string) (func() (string, error), error) {
×
UNCOV
390
        msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
×
UNCOV
391
        if err != nil {
×
392
                return nil, err
×
393
        }
×
394

UNCOV
395
        return func() (string, error) {
×
UNCOV
396
                err := msi.EnsureFresh()
×
UNCOV
397
                if err != nil {
×
UNCOV
398
                        return "", err
×
UNCOV
399
                }
×
400
                token := msi.OAuthToken()
×
401
                return token, nil
×
402
        }, nil
403
}
404

405
// getAADResourceFromServerUri derives the Azure AD resource for the server.
// The SQL Server resource can change across clouds, so it is computed from
// the server host instead of being hard-coded: everything after the first
// "." of the hostname (port stripped) is prefixed with "https://".
// ex. <server name>.database.windows.net -> https://database.windows.net
// A hostname without any "." yields just "https://".
func getAADResourceFromServerUri(purl *nurl.URL) string {
	_, domain, _ := strings.Cut(purl.Hostname(), ".")
	return fmt.Sprintf("https://%s", domain)
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc