• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 15511084249

16 May 2024 02:12PM UTC coverage: 56.314%. Remained the same
15511084249

Pull #1087

github

ccoVeille
chore: fix typos, acronym and styles
Pull Request #1087: chore: fix typos, acronym and styles

3 of 12 new or added lines in 3 files covered. (25.0%)

1 existing line in 1 file now uncovered.

4562 of 8101 relevant lines covered (56.31%)

50.1 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.43
/database/sqlserver/sqlserver.go
1
package sqlserver
2

3
import (
4
        "context"
5
        "database/sql"
6
        "fmt"
7
        "io"
8
        nurl "net/url"
9
        "strconv"
10
        "strings"
11

12
        "go.uber.org/atomic"
13

14
        "github.com/Azure/go-autorest/autorest/adal"
15
        "github.com/golang-migrate/migrate/v4"
16
        "github.com/golang-migrate/migrate/v4/database"
17
        "github.com/hashicorp/go-multierror"
18
        mssql "github.com/microsoft/go-mssqldb" // mssql support
19
)
20

21
func init() {
2✔
22
        database.Register("sqlserver", &SQLServer{})
2✔
23
}
2✔
24

25
// DefaultMigrationsTable is the migrations table name WithInstance falls back
// to when Config.MigrationsTable is empty (for example when Open is given no
// x-migrations-table URL parameter).
var DefaultMigrationsTable = "schema_migrations"
27

28
var (
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned by WithInstance when no database name is
	// configured and SELECT DB_NAME() yields an empty string.
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoSchema is returned by WithInstance when no schema is configured and
	// SELECT SCHEMA_NAME() yields an empty string.
	ErrNoSchema = fmt.Errorf("no schema")
	// ErrDatabaseDirty reports a dirty database. NOTE(review): it is not
	// returned anywhere in this file; confirm where it is used.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries
	// both a password and useMsi=true.
	// Error strings carry no trailing punctuation (Go convention).
	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed")
)
35

36
var lockErrorMap = map[mssql.ReturnStatus]string{
37
        -1:   "The lock request timed out.",
38
        -2:   "The lock request was canceled.",
39
        -3:   "The lock request was chosen as a deadlock victim.",
40
        -999: "Parameter validation or other call error.",
41
}
42

43
// Config holds the settings of a SQL Server migration driver. WithInstance
// fills in any empty field: DatabaseName from DB_NAME(), SchemaName from
// SCHEMA_NAME() and MigrationsTable from DefaultMigrationsTable.
type Config struct {
	// MigrationsTable is the table holding the current version and dirty flag.
	MigrationsTable string

	// DatabaseName is combined with SchemaName to derive the advisory lock ID.
	DatabaseName string

	// SchemaName is the schema that contains the migrations table.
	SchemaName string
}
49

50
// SQL Server connection
51
type SQLServer struct {
52
        // Locking and unlocking need to use the same connection
53
        conn     *sql.Conn
54
        db       *sql.DB
55
        isLocked atomic.Bool
56

57
        // Open and WithInstance need to garantuee that config is never nil
58
        config *Config
59
}
60

61
// WithInstance returns a database instance from an already created database connection.
62
//
63
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
64
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
32✔
65
        if config == nil {
32✔
66
                return nil, ErrNilConfig
×
67
        }
×
68

69
        if err := instance.Ping(); err != nil {
40✔
70
                return nil, err
8✔
71
        }
8✔
72

73
        if config.DatabaseName == "" {
48✔
74
                query := `SELECT DB_NAME()`
24✔
75
                var databaseName string
24✔
76
                if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
24✔
77
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
78
                }
×
79

80
                if len(databaseName) == 0 {
24✔
81
                        return nil, ErrNoDatabaseName
×
82
                }
×
83

84
                config.DatabaseName = databaseName
24✔
85
        }
86

87
        if config.SchemaName == "" {
48✔
88
                query := `SELECT SCHEMA_NAME()`
24✔
89
                var schemaName string
24✔
90
                if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
24✔
91
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
92
                }
×
93

94
                if len(schemaName) == 0 {
24✔
95
                        return nil, ErrNoSchema
×
96
                }
×
97

98
                config.SchemaName = schemaName
24✔
99
        }
100

101
        if len(config.MigrationsTable) == 0 {
48✔
102
                config.MigrationsTable = DefaultMigrationsTable
24✔
103
        }
24✔
104

105
        conn, err := instance.Conn(context.Background())
24✔
106

24✔
107
        if err != nil {
24✔
108
                return nil, err
×
109
        }
×
110

111
        ss := &SQLServer{
24✔
112
                conn:   conn,
24✔
113
                db:     instance,
24✔
114
                config: config,
24✔
115
        }
24✔
116

24✔
117
        if err := ss.ensureVersionTable(); err != nil {
24✔
118
                return nil, err
×
119
        }
×
120

121
        return ss, nil
24✔
122
}
123

124
// Open a connection to the database.
125
func (ss *SQLServer) Open(url string) (database.Driver, error) {
36✔
126
        purl, err := nurl.Parse(url)
36✔
127
        if err != nil {
36✔
128
                return nil, err
×
129
        }
×
130

131
        useMsiParam := purl.Query().Get("useMsi")
36✔
132
        useMsi := false
36✔
133
        if len(useMsiParam) > 0 {
52✔
134
                useMsi, err = strconv.ParseBool(useMsiParam)
16✔
135
                if err != nil {
16✔
136
                        return nil, err
×
137
                }
×
138
        }
139

140
        if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
40✔
141
                return nil, ErrMultipleAuthOptionsPassed
4✔
142
        }
4✔
143

144
        filteredURL := migrate.FilterCustomQuery(purl).String()
32✔
145

32✔
146
        var db *sql.DB
32✔
147
        if useMsi {
36✔
148
                resource := getAADResourceFromServerUri(purl)
4✔
149
                tokenProvider, err := getMSITokenProvider(resource)
4✔
150
                if err != nil {
4✔
151
                        return nil, err
×
152
                }
×
153

154
                connector, err := mssql.NewAccessTokenConnector(
4✔
155
                        filteredURL, tokenProvider)
4✔
156
                if err != nil {
4✔
157
                        return nil, err
×
158
                }
×
159

160
                db = sql.OpenDB(connector)
4✔
161

162
        } else {
28✔
163
                db, err = sql.Open("sqlserver", filteredURL)
28✔
164
                if err != nil {
28✔
165
                        return nil, err
×
166
                }
×
167
        }
168

169
        migrationsTable := purl.Query().Get("x-migrations-table")
32✔
170

32✔
171
        px, err := WithInstance(db, &Config{
32✔
172
                DatabaseName:    purl.Path,
32✔
173
                MigrationsTable: migrationsTable,
32✔
174
        })
32✔
175

32✔
176
        if err != nil {
40✔
177
                return nil, err
8✔
178
        }
8✔
179

180
        return px, nil
24✔
181
}
182

183
// Close the database connection
184
func (ss *SQLServer) Close() error {
20✔
185
        connErr := ss.conn.Close()
20✔
186
        dbErr := ss.db.Close()
20✔
187
        if connErr != nil || dbErr != nil {
20✔
188
                return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
×
189
        }
×
190
        return nil
20✔
191
}
192

193
// Lock creates an advisory lock on the database to prevent multiple migrations from running at the same time.
194
func (ss *SQLServer) Lock() error {
76✔
195
        return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
140✔
196
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
64✔
197
                if err != nil {
64✔
198
                        return err
×
199
                }
×
200

201
                // This will either obtain the lock immediately and return true,
202
                // or return false if the lock cannot be acquired immediately.
203
                // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
64✔
204
                query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
64✔
205

64✔
206
                var status mssql.ReturnStatus
64✔
207
                if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
64✔
208
                        return nil
64✔
209
                } else if err != nil {
128✔
210
                        return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
64✔
211
                } else {
64✔
212
                        return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
×
213
                }
×
214
        })
×
215
}
×
216

×
NEW
217
// Unlock the migration lock from the database
×
218
func (ss *SQLServer) Unlock() error {
×
219
        return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
220
                aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
221
                if err != nil {
222
                        return err
223
                }
224

64✔
225
                // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
128✔
226
                query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
64✔
227
                if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
64✔
228
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
229
                }
×
230

231
                return nil
232
        })
64✔
233
}
64✔
234

×
235
// Run the migrations for the database
×
236
func (ss *SQLServer) Run(migration io.Reader) error {
237
        migr, err := io.ReadAll(migration)
64✔
238
        if err != nil {
239
                return err
240
        }
241

242
        // run migration
56✔
243
        query := string(migr[:])
56✔
244
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
56✔
245
                if msErr, ok := err.(mssql.Error); ok {
×
246
                        message := fmt.Sprintf("migration failed: %s", msErr.Message)
×
247
                        if msErr.ProcName != "" {
248
                                message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
249
                        }
56✔
250
                        return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
60✔
251
                }
8✔
252
                return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
4✔
253
        }
4✔
254

×
255
        return nil
×
256
}
4✔
257

258
// SetVersion for the current database
×
259
func (ss *SQLServer) SetVersion(version int, dirty bool) error {
260

261
        tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
52✔
262
        if err != nil {
263
                return &database.Error{OrigErr: err, Err: "transaction start failed"}
264
        }
265

144✔
266
        query := `TRUNCATE TABLE ` + ss.getMigrationTable()
144✔
267
        if _, err := tx.Exec(query); err != nil {
144✔
268
                if errRollback := tx.Rollback(); errRollback != nil {
144✔
269
                        err = multierror.Append(err, errRollback)
×
270
                }
×
271
                return &database.Error{OrigErr: err, Query: []byte(query)}
272
        }
144✔
273

144✔
274
        // Also re-write the schema version for nil dirty versions to prevent
×
275
        // empty schema version for failed down migration on the first migration
×
276
        // See: https://github.com/golang-migrate/migrate/issues/330
×
277
        if version >= 0 || (version == database.NilVersion && dirty) {
×
278
                var dirtyBit int
279
                if dirty {
280
                        dirtyBit = 1
281
                }
282
                query = `INSERT INTO ` + ss.getMigrationTable() + ` (version, dirty) VALUES (@p1, @p2)`
283
                if _, err := tx.Exec(query, version, dirtyBit); err != nil {
276✔
284
                        if errRollback := tx.Rollback(); errRollback != nil {
132✔
285
                                err = multierror.Append(err, errRollback)
204✔
286
                        }
72✔
287
                        return &database.Error{OrigErr: err, Query: []byte(query)}
72✔
288
                }
132✔
289
        }
132✔
290

×
291
        if err := tx.Commit(); err != nil {
×
292
                return &database.Error{OrigErr: err, Err: "transaction commit failed"}
×
293
        }
×
294

295
        return nil
296
}
297

144✔
298
// Version of the current database state
×
299
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
×
300
        query := `SELECT TOP 1 version, dirty FROM ` + ss.getMigrationTable()
301
        err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
144✔
302
        switch {
303
        case err == sql.ErrNoRows:
304
                return database.NilVersion, false, nil
305

88✔
306
        case err != nil:
88✔
307
                // FIXME: convert to MSSQL error
88✔
308
                return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
88✔
309

28✔
310
        default:
28✔
311
                return version, dirty, nil
312
        }
×
313
}
×
314

×
315
// Drop all tables from the database.
316
func (ss *SQLServer) Drop() error {
60✔
317

60✔
318
        // drop all referential integrity constraints
319
        query := `
320
        DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
321

322
        SET @Cursor = CURSOR FAST_FORWARD FOR
16✔
323
        SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
16✔
324
        FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
16✔
325
        LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
16✔
326

16✔
327
        OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
16✔
328

16✔
329
        WHILE (@@FETCH_STATUS = 0)
16✔
330
        BEGIN
16✔
331
        Exec sp_executesql @Sql
16✔
332
        FETCH NEXT FROM @Cursor INTO @Sql
16✔
333
        END
16✔
334

16✔
335
        CLOSE @Cursor DEALLOCATE @Cursor`
16✔
336

16✔
337
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
16✔
338
                return &database.Error{OrigErr: err, Query: []byte(query)}
16✔
339
        }
16✔
340

16✔
341
        // drop the tables
16✔
342
        query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
16✔
343
        if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
16✔
344
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
345
        }
×
346

347
        return nil
348
}
16✔
349

16✔
350
func (ss *SQLServer) ensureVersionTable() (err error) {
×
351
        if err = ss.Lock(); err != nil {
×
352
                return err
353
        }
16✔
354

355
        defer func() {
356
                if e := ss.Unlock(); e != nil {
24✔
357
                        if err == nil {
24✔
358
                                err = e
×
359
                        } else {
×
360
                                err = multierror.Append(err, e)
361
                        }
48✔
362
                }
24✔
363
        }()
×
364

×
365
        query := `IF NOT EXISTS
×
366
        (SELECT *
×
367
                 FROM sysobjects
×
368
                WHERE id = object_id(N'` + ss.getMigrationTable() + `')
369
                        AND OBJECTPROPERTY(id, N'IsUserTable') = 1
370
        )
371
        CREATE TABLE ` + ss.getMigrationTable() + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
24✔
372

24✔
373
        if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
24✔
374
                return &database.Error{OrigErr: err, Query: []byte(query)}
24✔
375
        }
24✔
376

24✔
377
        return nil
24✔
378
}
24✔
379

24✔
380
func (ss *SQLServer) getMigrationTable() string {
×
381
        return fmt.Sprintf("[%s].[%s]", ss.config.SchemaName, ss.config.MigrationsTable)
×
382
}
383

24✔
384
func getMSITokenProvider(resource string) (func() (string, error), error) {
385
        msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
386
        if err != nil {
412✔
387
                return nil, err
412✔
388
        }
412✔
389

390
        return func() (string, error) {
4✔
391
                err := msi.EnsureFresh()
4✔
392
                if err != nil {
4✔
393
                        return "", err
×
394
                }
×
395
                token := msi.OAuthToken()
396
                return token, nil
8✔
397
        }, nil
4✔
398
}
8✔
399

4✔
400
// getAADResourceFromServerUri derives the Azure AD resource used to request a
// managed-identity token. The SQL Server resource differs between clouds, so
// it is computed from the server host instead of being hard-coded: the first
// DNS label is dropped and "https://" is prepended.
//
// ex. <server name>.database.windows.net -> https://database.windows.net
// A host without any dot yields just "https://".
func getAADResourceFromServerUri(purl *nurl.URL) string {
	_, domain, _ := strings.Cut(purl.Hostname(), ".")
	return fmt.Sprintf("https://%s", domain)
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc