• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 5186663860

pending completion
5186663860

Pull #932

github

longit644
Add Golang function as a source of migration
Pull Request #932: Add Golang function as a source of migration

252 of 370 new or added lines in 32 files covered. (68.11%)

4 existing lines in 4 files now uncovered.

4214 of 7231 relevant lines covered (58.28%)

61.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.72
/database/sqlserver/sqlserver.go
1
package sqlserver
2

3
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	nurl "net/url"
	"strconv"
	"strings"

	"go.uber.org/atomic"

	"github.com/Azure/go-autorest/autorest/adal"
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	"github.com/golang-migrate/migrate/v4/source"
	"github.com/hashicorp/go-multierror"
	mssql "github.com/microsoft/go-mssqldb" // mssql support
)
21

22
func init() {
2✔
23
        database.Register("sqlserver", &SQLServer{})
2✔
24
}
2✔
25

26
// DefaultMigrationsTable is the name of the migrations table used when
// neither Config.MigrationsTable nor the x-migrations-table URL parameter
// provides one.
var DefaultMigrationsTable = "schema_migrations"

// Sentinel errors returned by this driver.
var (
	// ErrNilConfig is returned by WithInstance when it is called with a nil config.
	ErrNilConfig = errors.New("no config")

	// ErrNoDatabaseName is returned by WithInstance when no database name was
	// configured and SELECT DB_NAME() yields an empty string.
	ErrNoDatabaseName = errors.New("no database name")

	// ErrNoSchema is returned by WithInstance when no schema was configured and
	// SELECT SCHEMA_NAME() yields an empty string.
	ErrNoSchema = errors.New("no schema")

	// ErrDatabaseDirty reports that the database is in a dirty state. It is not
	// returned by any code in this file; it is exported for callers.
	ErrDatabaseDirty = errors.New("database is dirty")

	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries both
	// a password and useMsi=true, since only one authentication method may be used.
	ErrMultipleAuthOptionsPassed = errors.New("both password and useMsi=true were passed.")
)
36

37
var lockErrorMap = map[mssql.ReturnStatus]string{
38
        -1:   "The lock request timed out.",
39
        -2:   "The lock request was canceled.",
40
        -3:   "The lock request was chosen as a deadlock victim.",
41
        -999: "Parameter validation or other call error.",
42
}
43

44
// Config holds the settings of a SQL Server driver instance. WithInstance
// fills in any field left empty.
type Config struct {
	// MigrationsTable is the table storing the current migration version;
	// WithInstance defaults it to DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName feeds the advisory-lock id; when empty, WithInstance
	// queries SELECT DB_NAME().
	DatabaseName string
	// SchemaName feeds the advisory-lock id; when empty, WithInstance
	// queries SELECT SCHEMA_NAME().
	SchemaName string
}
50

51
// SQL Server connection
52
type SQLServer struct {
53
        // Locking and unlocking need to use the same connection
54
        conn     *sql.Conn
55
        db       *sql.DB
56
        isLocked atomic.Bool
57

58
        // Open and WithInstance need to garantuee that config is never nil
59
        config *Config
60
}
61

62
// WithInstance returns a database instance from an already created database connection.
63
//
64
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
65
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
48✔
66
        if config == nil {
48✔
67
                return nil, ErrNilConfig
×
68
        }
×
69

70
        if err := instance.Ping(); err != nil {
60✔
71
                return nil, err
12✔
72
        }
12✔
73

74
        if config.DatabaseName == "" {
72✔
75
                query := `SELECT DB_NAME()`
36✔
76
                var databaseName string
36✔
77
                if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
36✔
78
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
79
                }
×
80

81
                if len(databaseName) == 0 {
36✔
82
                        return nil, ErrNoDatabaseName
×
83
                }
×
84

85
                config.DatabaseName = databaseName
36✔
86
        }
87

88
        if config.SchemaName == "" {
72✔
89
                query := `SELECT SCHEMA_NAME()`
36✔
90
                var schemaName string
36✔
91
                if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
36✔
92
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
93
                }
×
94

95
                if len(schemaName) == 0 {
36✔
96
                        return nil, ErrNoSchema
×
97
                }
×
98

99
                config.SchemaName = schemaName
36✔
100
        }
101

102
        if len(config.MigrationsTable) == 0 {
72✔
103
                config.MigrationsTable = DefaultMigrationsTable
36✔
104
        }
36✔
105

106
        conn, err := instance.Conn(context.Background())
36✔
107

36✔
108
        if err != nil {
36✔
109
                return nil, err
×
110
        }
×
111

112
        ss := &SQLServer{
36✔
113
                conn:   conn,
36✔
114
                db:     instance,
36✔
115
                config: config,
36✔
116
        }
36✔
117

36✔
118
        if err := ss.ensureVersionTable(); err != nil {
36✔
119
                return nil, err
×
120
        }
×
121

122
        return ss, nil
36✔
123
}
124

125
// Open a connection to the database.
126
func (ss *SQLServer) Open(url string) (database.Driver, error) {
54✔
127
        purl, err := nurl.Parse(url)
54✔
128
        if err != nil {
54✔
129
                return nil, err
×
130
        }
×
131

132
        useMsiParam := purl.Query().Get("useMsi")
54✔
133
        useMsi := false
54✔
134
        if len(useMsiParam) > 0 {
78✔
135
                useMsi, err = strconv.ParseBool(useMsiParam)
24✔
136
                if err != nil {
24✔
137
                        return nil, err
×
138
                }
×
139
        }
140

141
        if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
60✔
142
                return nil, ErrMultipleAuthOptionsPassed
6✔
143
        }
6✔
144

145
        filteredURL := migrate.FilterCustomQuery(purl).String()
48✔
146

48✔
147
        var db *sql.DB
48✔
148
        if useMsi {
54✔
149
                resource := getAADResourceFromServerUri(purl)
6✔
150
                tokenProvider, err := getMSITokenProvider(resource)
6✔
151
                if err != nil {
6✔
152
                        return nil, err
×
153
                }
×
154

155
                connector, err := mssql.NewAccessTokenConnector(
6✔
156
                        filteredURL, tokenProvider)
6✔
157
                if err != nil {
6✔
158
                        return nil, err
×
159
                }
×
160

161
                db = sql.OpenDB(connector)
6✔
162

163
        } else {
42✔
164
                db, err = sql.Open("sqlserver", filteredURL)
42✔
165
                if err != nil {
42✔
166
                        return nil, err
×
167
                }
×
168
        }
169

170
        migrationsTable := purl.Query().Get("x-migrations-table")
48✔
171

48✔
172
        px, err := WithInstance(db, &Config{
48✔
173
                DatabaseName:    purl.Path,
48✔
174
                MigrationsTable: migrationsTable,
48✔
175
        })
48✔
176

48✔
177
        if err != nil {
60✔
178
                return nil, err
12✔
179
        }
12✔
180

181
        return px, nil
36✔
182
}
183

184
// Close the database connection
185
func (ss *SQLServer) Close() error {
30✔
186
        connErr := ss.conn.Close()
30✔
187
        dbErr := ss.db.Close()
30✔
188
        if connErr != nil || dbErr != nil {
30✔
189
                return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
×
190
        }
×
191
        return nil
30✔
192
}
193

194
// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
195
func (ss *SQLServer) Lock() error {
114✔
196
        return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
210✔
197
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
96✔
198
                if err != nil {
96✔
199
                        return err
×
200
                }
×
201

202
                // This will either obtain the lock immediately and return true,
203
                // or return false if the lock cannot be acquired immediately.
204
                // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
205
                query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
96✔
206

96✔
207
                var status mssql.ReturnStatus
96✔
208
                if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
192✔
209
                        return nil
96✔
210
                } else if err != nil {
96✔
211
                        return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
×
212
                } else {
×
213
                        return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
×
214
                }
×
215
        })
216
}
217

218
// Unlock froms the migration lock from the database
219
func (ss *SQLServer) Unlock() error {
96✔
220
        return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
192✔
221
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
96✔
222
                if err != nil {
96✔
223
                        return err
×
224
                }
×
225

226
                // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
227
                query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
96✔
228
                if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
96✔
229
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
230
                }
×
231

232
                return nil
96✔
233
        })
234
}
235

236
// Run the migrations for the database
237
func (ss *SQLServer) Run(migration io.Reader) error {
84✔
238
        migr, err := io.ReadAll(migration)
84✔
239
        if err != nil {
84✔
240
                return err
×
241
        }
×
242

243
        // run migration
244
        query := string(migr[:])
84✔
245
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
90✔
246
                if msErr, ok := err.(mssql.Error); ok {
12✔
247
                        message := fmt.Sprintf("migration failed: %s", msErr.Message)
6✔
248
                        if msErr.ProcName != "" {
6✔
249
                                message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
×
250
                        }
×
251
                        return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
6✔
252
                }
253
                return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
×
254
        }
255

256
        return nil
78✔
257
}
258

259
// SetVersion for the current database
260
func (ss *SQLServer) SetVersion(version int, dirty bool) error {
216✔
261

216✔
262
        tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
216✔
263
        if err != nil {
216✔
264
                return &database.Error{OrigErr: err, Err: "transaction start failed"}
×
265
        }
×
266

267
        query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
216✔
268
        if _, err := tx.Exec(query); err != nil {
216✔
269
                if errRollback := tx.Rollback(); errRollback != nil {
×
270
                        err = multierror.Append(err, errRollback)
×
271
                }
×
272
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
273
        }
274

275
        // Also re-write the schema version for nil dirty versions to prevent
276
        // empty schema version for failed down migration on the first migration
277
        // See: https://github.com/golang-migrate/migrate/issues/330
278
        if version >= 0 || (version == database.NilVersion && dirty) {
414✔
279
                var dirtyBit int
198✔
280
                if dirty {
306✔
281
                        dirtyBit = 1
108✔
282
                }
108✔
283
                query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
198✔
284
                if _, err := tx.Exec(query, version, dirtyBit); err != nil {
198✔
285
                        if errRollback := tx.Rollback(); errRollback != nil {
×
286
                                err = multierror.Append(err, errRollback)
×
287
                        }
×
288
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
289
                }
290
        }
291

292
        if err := tx.Commit(); err != nil {
216✔
293
                return &database.Error{OrigErr: err, Err: "transaction commit failed"}
×
294
        }
×
295

296
        return nil
216✔
297
}
298

299
// Version of the current database state
300
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
132✔
301
        query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
132✔
302
        err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
132✔
303
        switch {
132✔
304
        case err == sql.ErrNoRows:
42✔
305
                return database.NilVersion, false, nil
42✔
306

307
        case err != nil:
×
308
                // FIXME: convert to MSSQL error
×
309
                return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
×
310

311
        default:
90✔
312
                return version, dirty, nil
90✔
313
        }
314
}
315

316
// Drop all tables from the database.
317
func (ss *SQLServer) Drop() error {
24✔
318

24✔
319
        // drop all referential integrity constraints
24✔
320
        query := `
24✔
321
        DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
24✔
322

24✔
323
        SET @Cursor = CURSOR FAST_FORWARD FOR
24✔
324
        SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
24✔
325
        FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
24✔
326
        LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
24✔
327

24✔
328
        OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
24✔
329

24✔
330
        WHILE (@@FETCH_STATUS = 0)
24✔
331
        BEGIN
24✔
332
        Exec sp_executesql @Sql
24✔
333
        FETCH NEXT FROM @Cursor INTO @Sql
24✔
334
        END
24✔
335

24✔
336
        CLOSE @Cursor DEALLOCATE @Cursor`
24✔
337

24✔
338
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
24✔
339
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
340
        }
×
341

342
        // drop the tables
343
        query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
24✔
344
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
24✔
345
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
346
        }
×
347

348
        return nil
24✔
349
}
350

351
// Exec implements database.Driver. Executes a migration exectuor.
NEW
352
func (ss *SQLServer) Exec(e source.Executor) error {
×
NEW
353
        return e.Execute(ss.db)
×
NEW
354
}
×
355

356
func (ss *SQLServer) ensureVersionTable() (err error) {
36✔
357
        if err = ss.Lock(); err != nil {
36✔
358
                return err
×
359
        }
×
360

361
        defer func() {
72✔
362
                if e := ss.Unlock(); e != nil {
36✔
363
                        if err == nil {
×
364
                                err = e
×
365
                        } else {
×
366
                                err = multierror.Append(err, e)
×
367
                        }
×
368
                }
369
        }()
370

371
        query := `IF NOT EXISTS
36✔
372
        (SELECT *
36✔
373
                 FROM sysobjects
36✔
374
                WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
36✔
375
                        AND OBJECTPROPERTY(id, N'IsUserTable') = 1
36✔
376
        )
36✔
377
        CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
36✔
378

36✔
379
        if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
36✔
380
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
381
        }
×
382

383
        return nil
36✔
384
}
385

386
func getMSITokenProvider(resource string) (func() (string, error), error) {
6✔
387
        msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
6✔
388
        if err != nil {
6✔
389
                return nil, err
×
390
        }
×
391

392
        return func() (string, error) {
12✔
393
                err := msi.EnsureFresh()
6✔
394
                if err != nil {
12✔
395
                        return "", err
6✔
396
                }
6✔
397
                token := msi.OAuthToken()
×
398
                return token, nil
×
399
        }, nil
400
}
401

402
// getAADResourceFromServerUri derives the Azure AD token resource from the
// server host name. The resource differs between Azure clouds, so it is built
// by dropping the first DNS label of the host and prefixing "https://".
// ex. <server name>.database.windows.net -> https://database.windows.net
// A host without any dot yields just "https://".
func getAADResourceFromServerUri(purl *nurl.URL) string {
	host := purl.Hostname()
	domain := ""
	if dot := strings.IndexByte(host, '.'); dot >= 0 {
		domain = host[dot+1:]
	}
	return fmt.Sprintf("https://%s", domain)
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc