• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 5186663860

pending completion
5186663860

Pull #932

github

longit644
Add Golang function as a source of migration
Pull Request #932: Add Golang function as a source of migration

252 of 370 new or added lines in 32 files covered. (68.11%)

4 existing lines in 4 files now uncovered.

4214 of 7231 relevant lines covered (58.28%)

61.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.02
/database/postgres/postgres.go
1
//go:build go1.9
2
// +build go1.9
3

4
package postgres
5

6
import (
7
        "context"
8
        "database/sql"
9
        "fmt"
10
        "io"
11
        nurl "net/url"
12
        "regexp"
13
        "strconv"
14
        "strings"
15
        "time"
16

17
        "go.uber.org/atomic"
18

19
        "github.com/golang-migrate/migrate/v4"
20
        "github.com/golang-migrate/migrate/v4/database"
21
        "github.com/golang-migrate/migrate/v4/database/multistmt"
22
        "github.com/golang-migrate/migrate/v4/source"
23
        "github.com/hashicorp/go-multierror"
24
        "github.com/lib/pq"
25
)
26

27
func init() {
2✔
28
        db := Postgres{}
2✔
29
        database.Register("postgres", &db)
2✔
30
        database.Register("postgresql", &db)
2✔
31
}
2✔
32

33
// multiStmtDelimiter is the separator Run hands to multistmt.Parse when
// Config.MultiStatementEnabled is set.
var multiStmtDelimiter = []byte(";")

// DefaultMigrationsTable is the migrations table name WithConnection uses when
// Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"

// DefaultMultiStatementMaxSize (10 MiB) is the MultiStatementMaxSize Open uses
// when x-multi-statement-max-size is absent or not a positive number.
var DefaultMultiStatementMaxSize = 10 << 20
39

40
// ErrNilConfig is returned by WithConnection when config is nil.
var ErrNilConfig = fmt.Errorf("no config")

// ErrNoDatabaseName is returned by WithConnection when Config.DatabaseName is
// empty and CURRENT_DATABASE() also yields an empty name.
var ErrNoDatabaseName = fmt.Errorf("no database name")

// ErrNoSchema is returned by WithConnection when Config.SchemaName is empty and
// CURRENT_SCHEMA() returns NULL.
var ErrNoSchema = fmt.Errorf("no schema")

// ErrDatabaseDirty is not referenced in this file; presumably it is used by
// callers to report a dirty migration state — verify before relying on it.
var ErrDatabaseDirty = fmt.Errorf("database is dirty")
46

47
// Config controls how a Postgres driver is set up. WithConnection fills in
// several of these fields in place.
type Config struct {
	// MigrationsTable names the migrations table; WithConnection defaults it to
	// DefaultMigrationsTable when empty.
	MigrationsTable string

	// MigrationsTableQuoted: MigrationsTable is written as double-quoted
	// identifiers, either "table" or "schema"."table".
	// MultiStatementEnabled: Run splits each migration on ";" and executes the
	// statements one at a time.
	MigrationsTableQuoted, MultiStatementEnabled bool

	// DatabaseName and SchemaName default to CURRENT_DATABASE() and
	// CURRENT_SCHEMA() respectively when left empty.
	DatabaseName, SchemaName string

	// Resolved by WithConnection from SchemaName / MigrationsTable; used for
	// the advisory lock ID and every query against the migrations table.
	migrationsSchemaName, migrationsTableName string

	// StatementTimeout bounds each statement executed by Run; zero disables it.
	StatementTimeout time.Duration

	// MultiStatementMaxSize is passed to multistmt.Parse when
	// MultiStatementEnabled is set.
	MultiStatementMaxSize int
}
58

59
type Postgres struct {
60
        // Locking and unlocking need to use the same connection
61
        conn     *sql.Conn
62
        db       *sql.DB
63
        isLocked atomic.Bool
64

65
        // Open and WithInstance need to guarantee that config is never nil
66
        config *Config
67
}
68

69
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
530✔
70
        if config == nil {
530✔
71
                return nil, ErrNilConfig
×
72
        }
×
73

74
        if err := conn.PingContext(ctx); err != nil {
530✔
75
                return nil, err
×
76
        }
×
77

78
        if config.DatabaseName == "" {
840✔
79
                query := `SELECT CURRENT_DATABASE()`
310✔
80
                var databaseName string
310✔
81
                if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
310✔
82
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
83
                }
×
84

85
                if len(databaseName) == 0 {
310✔
86
                        return nil, ErrNoDatabaseName
×
87
                }
×
88

89
                config.DatabaseName = databaseName
310✔
90
        }
91

92
        if config.SchemaName == "" {
1,060✔
93
                query := `SELECT CURRENT_SCHEMA()`
530✔
94
                var schemaName sql.NullString
530✔
95
                if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
530✔
96
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
97
                }
×
98

99
                if !schemaName.Valid {
530✔
100
                        return nil, ErrNoSchema
×
101
                }
×
102

103
                config.SchemaName = schemaName.String
530✔
104
        }
105

106
        if len(config.MigrationsTable) == 0 {
1,020✔
107
                config.MigrationsTable = DefaultMigrationsTable
490✔
108
        }
490✔
109

110
        config.migrationsSchemaName = config.SchemaName
530✔
111
        config.migrationsTableName = config.MigrationsTable
530✔
112
        if config.MigrationsTableQuoted {
560✔
113
                re := regexp.MustCompile(`"(.*?)"`)
30✔
114
                result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
30✔
115
                config.migrationsTableName = result[len(result)-1][1]
30✔
116
                if len(result) == 2 {
50✔
117
                        config.migrationsSchemaName = result[0][1]
20✔
118
                } else if len(result) > 2 {
40✔
119
                        return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
10✔
120
                }
10✔
121
        }
122

123
        px := &Postgres{
520✔
124
                conn:   conn,
520✔
125
                config: config,
520✔
126
        }
520✔
127

520✔
128
        if err := px.ensureVersionTable(); err != nil {
540✔
129
                return nil, err
20✔
130
        }
20✔
131

132
        return px, nil
500✔
133
}
134

135
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
520✔
136
        ctx := context.Background()
520✔
137

520✔
138
        if err := instance.Ping(); err != nil {
520✔
139
                return nil, err
×
140
        }
×
141

142
        conn, err := instance.Conn(ctx)
520✔
143
        if err != nil {
520✔
144
                return nil, err
×
145
        }
×
146

147
        px, err := WithConnection(ctx, conn, config)
520✔
148
        if err != nil {
550✔
149
                return nil, err
30✔
150
        }
30✔
151
        px.db = instance
490✔
152
        return px, nil
490✔
153
}
154

155
func (p *Postgres) Open(url string) (database.Driver, error) {
230✔
156
        purl, err := nurl.Parse(url)
230✔
157
        if err != nil {
230✔
158
                return nil, err
×
159
        }
×
160

161
        db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
230✔
162
        if err != nil {
230✔
163
                return nil, err
×
164
        }
×
165

166
        migrationsTable := purl.Query().Get("x-migrations-table")
230✔
167
        migrationsTableQuoted := false
230✔
168
        if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
270✔
169
                migrationsTableQuoted, err = strconv.ParseBool(s)
40✔
170
                if err != nil {
40✔
171
                        return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
×
172
                }
×
173
        }
174
        if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
240✔
175
                return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
10✔
176
        }
10✔
177

178
        statementTimeoutString := purl.Query().Get("x-statement-timeout")
220✔
179
        statementTimeout := 0
220✔
180
        if statementTimeoutString != "" {
220✔
181
                statementTimeout, err = strconv.Atoi(statementTimeoutString)
×
182
                if err != nil {
×
183
                        return nil, err
×
184
                }
×
185
        }
186

187
        multiStatementMaxSize := DefaultMultiStatementMaxSize
220✔
188
        if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
220✔
189
                multiStatementMaxSize, err = strconv.Atoi(s)
×
190
                if err != nil {
×
191
                        return nil, err
×
192
                }
×
193
                if multiStatementMaxSize <= 0 {
×
194
                        multiStatementMaxSize = DefaultMultiStatementMaxSize
×
195
                }
×
196
        }
197

198
        multiStatementEnabled := false
220✔
199
        if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
230✔
200
                multiStatementEnabled, err = strconv.ParseBool(s)
10✔
201
                if err != nil {
10✔
202
                        return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
×
203
                }
×
204
        }
205

206
        px, err := WithInstance(db, &Config{
220✔
207
                DatabaseName:          purl.Path,
220✔
208
                MigrationsTable:       migrationsTable,
220✔
209
                MigrationsTableQuoted: migrationsTableQuoted,
220✔
210
                StatementTimeout:      time.Duration(statementTimeout) * time.Millisecond,
220✔
211
                MultiStatementEnabled: multiStatementEnabled,
220✔
212
                MultiStatementMaxSize: multiStatementMaxSize,
220✔
213
        })
220✔
214

220✔
215
        if err != nil {
250✔
216
                return nil, err
30✔
217
        }
30✔
218

219
        return px, nil
190✔
220
}
221

222
func (p *Postgres) Close() error {
170✔
223
        connErr := p.conn.Close()
170✔
224
        var dbErr error
170✔
225
        if p.db != nil {
330✔
226
                dbErr = p.db.Close()
160✔
227
        }
160✔
228

229
        if connErr != nil || dbErr != nil {
170✔
230
                return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
×
231
        }
×
232
        return nil
170✔
233
}
234

235
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
236
func (p *Postgres) Lock() error {
670✔
237
        return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
1,310✔
238
                aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
640✔
239
                if err != nil {
640✔
240
                        return err
×
241
                }
×
242

243
                // This will wait indefinitely until the lock can be acquired.
244
                query := `SELECT pg_advisory_lock($1)`
640✔
245
                if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
640✔
246
                        return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
×
247
                }
×
248

249
                return nil
640✔
250
        })
251
}
252

253
func (p *Postgres) Unlock() error {
640✔
254
        return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
1,280✔
255
                aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
640✔
256
                if err != nil {
640✔
257
                        return err
×
258
                }
×
259

260
                query := `SELECT pg_advisory_unlock($1)`
640✔
261
                if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
640✔
262
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
263
                }
×
264
                return nil
640✔
265
        })
266
}
267

268
func (p *Postgres) Run(migration io.Reader) error {
300✔
269
        if p.config.MultiStatementEnabled {
310✔
270
                var err error
10✔
271
                if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
30✔
272
                        if err = p.runStatement(m); err != nil {
20✔
273
                                return false
×
274
                        }
×
275
                        return true
20✔
276
                }); e != nil {
×
277
                        return e
×
278
                }
×
279
                return err
10✔
280
        }
281
        migr, err := io.ReadAll(migration)
290✔
282
        if err != nil {
290✔
283
                return err
×
284
        }
×
285
        return p.runStatement(migr)
290✔
286
}
287

288
func (p *Postgres) runStatement(statement []byte) error {
310✔
289
        ctx := context.Background()
310✔
290
        if p.config.StatementTimeout != 0 {
310✔
291
                var cancel context.CancelFunc
×
292
                ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
×
293
                defer cancel()
×
294
        }
×
295
        query := string(statement)
310✔
296
        if strings.TrimSpace(query) == "" {
310✔
297
                return nil
×
298
        }
×
299
        if _, err := p.conn.ExecContext(ctx, query); err != nil {
320✔
300
                if pgErr, ok := err.(*pq.Error); ok {
20✔
301
                        var line uint
10✔
302
                        var col uint
10✔
303
                        var lineColOK bool
10✔
304
                        if pgErr.Position != "" {
20✔
305
                                if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
20✔
306
                                        line, col, lineColOK = computeLineFromPos(query, int(pos))
10✔
307
                                }
10✔
308
                        }
309
                        message := fmt.Sprintf("migration failed: %s", pgErr.Message)
10✔
310
                        if lineColOK {
20✔
311
                                message = fmt.Sprintf("%s (column %d)", message, col)
10✔
312
                        }
10✔
313
                        if pgErr.Detail != "" {
10✔
314
                                message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
×
315
                        }
×
316
                        return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
10✔
317
                }
318
                return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
×
319
        }
320
        return nil
300✔
321
}
322

323
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
74✔
324
        // replace crlf with lf
74✔
325
        s = strings.Replace(s, "\r\n", "\n", -1)
74✔
326
        // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
74✔
327
        runes := []rune(s)
74✔
328
        if pos > len(runes) {
82✔
329
                return 0, 0, false
8✔
330
        }
8✔
331
        sel := runes[:pos]
66✔
332
        line = uint(runesCount(sel, newLine) + 1)
66✔
333
        col = uint(pos - 1 - runesLastIndex(sel, newLine))
66✔
334
        return line, col, true
66✔
335
}
336

337
const (
	// newLine is the line separator computeLineFromPos counts, after CRLF has
	// been normalized to LF.
	newLine = '\n'
)
338

339
// runesCount reports how many elements of input equal target.
func runesCount(input []rune, target rune) int {
	n := 0
	for i := range input {
		if input[i] == target {
			n++
		}
	}
	return n
}
348

349
// runesLastIndex returns the index of the last element of input equal to
// target, or -1 when target does not occur.
func runesLastIndex(input []rune, target rune) int {
	i := len(input) - 1
	for i >= 0 && input[i] != target {
		i--
	}
	return i
}
357

358
func (p *Postgres) SetVersion(version int, dirty bool) error {
380✔
359
        tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
380✔
360
        if err != nil {
380✔
361
                return &database.Error{OrigErr: err, Err: "transaction start failed"}
×
362
        }
×
363

364
        query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName)
380✔
365
        if _, err := tx.Exec(query); err != nil {
380✔
366
                if errRollback := tx.Rollback(); errRollback != nil {
×
367
                        err = multierror.Append(err, errRollback)
×
368
                }
×
369
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
370
        }
371

372
        // Also re-write the schema version for nil dirty versions to prevent
373
        // empty schema version for failed down migration on the first migration
374
        // See: https://github.com/golang-migrate/migrate/issues/330
375
        if version >= 0 || (version == database.NilVersion && dirty) {
730✔
376
                query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
350✔
377
                if _, err := tx.Exec(query, version, dirty); err != nil {
350✔
378
                        if errRollback := tx.Rollback(); errRollback != nil {
×
379
                                err = multierror.Append(err, errRollback)
×
380
                        }
×
381
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
382
                }
383
        }
384

385
        if err := tx.Commit(); err != nil {
380✔
386
                return &database.Error{OrigErr: err, Err: "transaction commit failed"}
×
387
        }
×
388

389
        return nil
380✔
390
}
391

392
func (p *Postgres) Version() (version int, dirty bool, err error) {
260✔
393
        query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
260✔
394
        err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
260✔
395
        switch {
260✔
396
        case err == sql.ErrNoRows:
90✔
397
                return database.NilVersion, false, nil
90✔
398

399
        case err != nil:
×
400
                if e, ok := err.(*pq.Error); ok {
×
401
                        if e.Code.Name() == "undefined_table" {
×
402
                                return database.NilVersion, false, nil
×
403
                        }
×
404
                }
405
                return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
×
406

407
        default:
170✔
408
                return version, dirty, nil
170✔
409
        }
410
}
411

412
func (p *Postgres) Drop() (err error) {
40✔
413
        // select all tables in current schema
40✔
414
        query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
40✔
415
        tables, err := p.conn.QueryContext(context.Background(), query)
40✔
416
        if err != nil {
40✔
417
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
418
        }
×
419
        defer func() {
80✔
420
                if errClose := tables.Close(); errClose != nil {
40✔
421
                        err = multierror.Append(err, errClose)
×
422
                }
×
423
        }()
424

425
        // delete one table after another
426
        tableNames := make([]string, 0)
40✔
427
        for tables.Next() {
110✔
428
                var tableName string
70✔
429
                if err := tables.Scan(&tableName); err != nil {
70✔
430
                        return err
×
431
                }
×
432
                if len(tableName) > 0 {
140✔
433
                        tableNames = append(tableNames, tableName)
70✔
434
                }
70✔
435
        }
436
        if err := tables.Err(); err != nil {
40✔
437
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
438
        }
×
439

440
        if len(tableNames) > 0 {
80✔
441
                // delete one by one ...
40✔
442
                for _, t := range tableNames {
110✔
443
                        query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
70✔
444
                        if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
70✔
445
                                return &database.Error{OrigErr: err, Query: []byte(query)}
×
446
                        }
×
447
                }
448
        }
449

450
        return nil
40✔
451
}
452

NEW
453
func (p *Postgres) Exec(e source.Executor) error {
×
NEW
454
        return e.Execute(p.db)
×
NEW
455
}
×
456

457
// ensureVersionTable checks if versions table exists and, if not, creates it.
458
// Note that this function locks the database, which deviates from the usual
459
// convention of "caller locks" in the Postgres type.
460
func (p *Postgres) ensureVersionTable() (err error) {
520✔
461
        if err = p.Lock(); err != nil {
520✔
462
                return err
×
463
        }
×
464

465
        defer func() {
1,040✔
466
                if e := p.Unlock(); e != nil {
520✔
467
                        if err == nil {
×
468
                                err = e
×
469
                        } else {
×
470
                                err = multierror.Append(err, e)
×
471
                        }
×
472
                }
473
        }()
474

475
        // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
476
        // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
477
        // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
478
        // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
479
        query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
520✔
480
        row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
520✔
481

520✔
482
        var count int
520✔
483
        err = row.Scan(&count)
520✔
484
        if err != nil {
520✔
485
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
486
        }
×
487

488
        if count == 1 {
820✔
489
                return nil
300✔
490
        }
300✔
491

492
        query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
220✔
493
        if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
240✔
494
                return &database.Error{OrigErr: err, Query: []byte(query)}
20✔
495
        }
20✔
496

497
        return nil
200✔
498
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc