• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 17178144630

23 Aug 2025 05:20PM UTC coverage: 56.309% (-0.005%) from 56.314%
17178144630

Pull #1308

github

chandrakant-cohesity
Fixed defer block
Pull Request #1308: Ensure bufferWriter is always closed in Migration.Buffer and propagate close errors

7 of 11 new or added lines in 1 file covered. (63.64%)

557 existing lines in 16 files now uncovered.

4561 of 8100 relevant lines covered (56.31%)

50.12 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.66
/database/postgres/postgres.go
1
//go:build go1.9
2

3
package postgres
4

5
import (
6
        "context"
7
        "database/sql"
8
        "fmt"
9
        "io"
10
        nurl "net/url"
11
        "regexp"
12
        "strconv"
13
        "strings"
14
        "time"
15

16
        "go.uber.org/atomic"
17

18
        "github.com/golang-migrate/migrate/v4"
19
        "github.com/golang-migrate/migrate/v4/database"
20
        "github.com/golang-migrate/migrate/v4/database/multistmt"
21
        "github.com/hashicorp/go-multierror"
22
        "github.com/lib/pq"
23
)
24

2✔
25
func init() {
2✔
26
        db := Postgres{}
2✔
27
        database.Register("postgres", &db)
2✔
28
        database.Register("postgresql", &db)
2✔
29
}
30

31
var (
	// multiStmtDelimiter is the delimiter handed to the multi-statement
	// parser when a migration is split into individual statements.
	multiStmtDelimiter = []byte(";")

	// DefaultMigrationsTable is the migrations table name used when
	// Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"
	// DefaultMultiStatementMaxSize is the multi-statement size limit used
	// when none (or a non-positive one) is supplied through the URL.
	DefaultMultiStatementMaxSize = 10 << 20 // 10 MB
)
37

38
var (
	// ErrNilConfig is returned by WithConnection when it is given a nil
	// *Config.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned when no database name was configured and
	// the server reports an empty CURRENT_DATABASE().
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoSchema is returned when no schema was configured and the server
	// reports a NULL CURRENT_SCHEMA().
	ErrNoSchema = fmt.Errorf("no schema")
	// ErrDatabaseDirty signals a dirty database.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
44

45
// Config carries the settings of a Postgres driver instance. WithConnection
// fills in empty fields and stores the resolved schema/table names in the
// unexported fields, so a Config is modified by the driver that receives it.
type Config struct {
	// MigrationsTable names the table holding the migration version;
	// DefaultMigrationsTable is used when it is empty.
	MigrationsTable string
	// MigrationsTableQuoted makes the driver read MigrationsTable as one or
	// two double-quoted identifiers ("table" or "schema"."table").
	MigrationsTableQuoted bool
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// the pieces one by one instead of sending the whole body at once.
	MultiStatementEnabled bool
	// DatabaseName defaults to the server's CURRENT_DATABASE() when empty.
	DatabaseName string
	// SchemaName defaults to the server's CURRENT_SCHEMA() when empty.
	SchemaName string
	// migrationsSchemaName and migrationsTableName are the resolved location
	// of the version table, derived from the fields above.
	migrationsSchemaName string
	migrationsTableName  string
	// StatementTimeout, when non-zero, bounds the execution of each statement
	// run by the driver.
	StatementTimeout time.Duration
	// MultiStatementMaxSize is the size limit passed to the multi-statement
	// parser.
	MultiStatementMaxSize int
}
56

57
type Postgres struct {
58
        // Locking and unlocking need to use the same connection
59
        conn     *sql.Conn
60
        db       *sql.DB
61
        isLocked atomic.Bool
62

63
        // Open and WithInstance need to guarantee that config is never nil
64
        config *Config
65
}
66

530✔
67
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
530✔
UNCOV
68
        if config == nil {
×
69
                return nil, ErrNilConfig
×
70
        }
71

530✔
UNCOV
72
        if err := conn.PingContext(ctx); err != nil {
×
73
                return nil, err
×
74
        }
75

840✔
76
        if config.DatabaseName == "" {
310✔
77
                query := `SELECT CURRENT_DATABASE()`
310✔
78
                var databaseName string
310✔
UNCOV
79
                if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
×
80
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
81
                }
82

310✔
UNCOV
83
                if len(databaseName) == 0 {
×
84
                        return nil, ErrNoDatabaseName
×
85
                }
86

310✔
87
                config.DatabaseName = databaseName
88
        }
89

1,060✔
90
        if config.SchemaName == "" {
530✔
91
                query := `SELECT CURRENT_SCHEMA()`
530✔
92
                var schemaName sql.NullString
530✔
UNCOV
93
                if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
×
94
                        return nil, &database.Error{OrigErr: err, Query: []byte(query)}
×
95
                }
96

530✔
UNCOV
97
                if !schemaName.Valid {
×
98
                        return nil, ErrNoSchema
×
99
                }
100

530✔
101
                config.SchemaName = schemaName.String
102
        }
103

1,020✔
104
        if len(config.MigrationsTable) == 0 {
490✔
105
                config.MigrationsTable = DefaultMigrationsTable
490✔
106
        }
107

530✔
108
        config.migrationsSchemaName = config.SchemaName
530✔
109
        config.migrationsTableName = config.MigrationsTable
560✔
110
        if config.MigrationsTableQuoted {
30✔
111
                re := regexp.MustCompile(`"(.*?)"`)
30✔
112
                result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
30✔
113
                config.migrationsTableName = result[len(result)-1][1]
50✔
114
                if len(result) == 2 {
20✔
115
                        config.migrationsSchemaName = result[0][1]
40✔
116
                } else if len(result) > 2 {
10✔
117
                        return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
10✔
118
                }
119
        }
120

520✔
121
        px := &Postgres{
520✔
122
                conn:   conn,
520✔
123
                config: config,
520✔
124
        }
520✔
125

540✔
126
        if err := px.ensureVersionTable(); err != nil {
20✔
127
                return nil, err
20✔
128
        }
129

500✔
130
        return px, nil
131
}
132

520✔
133
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
520✔
134
        ctx := context.Background()
520✔
135

520✔
UNCOV
136
        if err := instance.Ping(); err != nil {
×
137
                return nil, err
×
138
        }
139

520✔
140
        conn, err := instance.Conn(ctx)
520✔
UNCOV
141
        if err != nil {
×
142
                return nil, err
×
143
        }
144

520✔
145
        px, err := WithConnection(ctx, conn, config)
550✔
146
        if err != nil {
30✔
147
                return nil, err
30✔
148
        }
490✔
149
        px.db = instance
490✔
150
        return px, nil
151
}
152

230✔
153
func (p *Postgres) Open(url string) (database.Driver, error) {
230✔
154
        purl, err := nurl.Parse(url)
230✔
UNCOV
155
        if err != nil {
×
156
                return nil, err
×
157
        }
158

230✔
159
        db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
230✔
UNCOV
160
        if err != nil {
×
161
                return nil, err
×
162
        }
163

230✔
164
        migrationsTable := purl.Query().Get("x-migrations-table")
230✔
165
        migrationsTableQuoted := false
270✔
166
        if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
40✔
167
                migrationsTableQuoted, err = strconv.ParseBool(s)
40✔
UNCOV
168
                if err != nil {
×
169
                        return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
×
170
                }
171
        }
240✔
172
        if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
10✔
173
                return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
10✔
174
        }
175

220✔
176
        statementTimeoutString := purl.Query().Get("x-statement-timeout")
220✔
177
        statementTimeout := 0
220✔
UNCOV
178
        if statementTimeoutString != "" {
×
179
                statementTimeout, err = strconv.Atoi(statementTimeoutString)
×
180
                if err != nil {
×
181
                        return nil, err
×
182
                }
183
        }
184

220✔
185
        multiStatementMaxSize := DefaultMultiStatementMaxSize
220✔
UNCOV
186
        if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
×
187
                multiStatementMaxSize, err = strconv.Atoi(s)
×
188
                if err != nil {
×
189
                        return nil, err
×
190
                }
×
191
                if multiStatementMaxSize <= 0 {
×
192
                        multiStatementMaxSize = DefaultMultiStatementMaxSize
×
193
                }
194
        }
195

220✔
196
        multiStatementEnabled := false
230✔
197
        if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
10✔
198
                multiStatementEnabled, err = strconv.ParseBool(s)
10✔
UNCOV
199
                if err != nil {
×
200
                        return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
×
201
                }
202
        }
203

220✔
204
        px, err := WithInstance(db, &Config{
220✔
205
                DatabaseName:          purl.Path,
220✔
206
                MigrationsTable:       migrationsTable,
220✔
207
                MigrationsTableQuoted: migrationsTableQuoted,
220✔
208
                StatementTimeout:      time.Duration(statementTimeout) * time.Millisecond,
220✔
209
                MultiStatementEnabled: multiStatementEnabled,
220✔
210
                MultiStatementMaxSize: multiStatementMaxSize,
220✔
211
        })
220✔
212

250✔
213
        if err != nil {
30✔
214
                return nil, err
30✔
215
        }
216

190✔
217
        return px, nil
218
}
219

170✔
220
func (p *Postgres) Close() error {
170✔
221
        connErr := p.conn.Close()
170✔
222
        var dbErr error
330✔
223
        if p.db != nil {
160✔
224
                dbErr = p.db.Close()
160✔
225
        }
226

170✔
UNCOV
227
        if connErr != nil || dbErr != nil {
×
228
                return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
×
229
        }
170✔
230
        return nil
231
}
232

233
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
670✔
234
func (p *Postgres) Lock() error {
1,310✔
235
        return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
640✔
236
                aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
640✔
UNCOV
237
                if err != nil {
×
238
                        return err
×
239
                }
240

241
                // This will wait indefinitely until the lock can be acquired.
640✔
242
                query := `SELECT pg_advisory_lock($1)`
640✔
UNCOV
243
                if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
×
244
                        return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
×
245
                }
246

640✔
247
                return nil
248
        })
249
}
250

640✔
251
func (p *Postgres) Unlock() error {
1,280✔
252
        return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
640✔
253
                aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
640✔
UNCOV
254
                if err != nil {
×
255
                        return err
×
256
                }
257

640✔
258
                query := `SELECT pg_advisory_unlock($1)`
640✔
UNCOV
259
                if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
×
260
                        return &database.Error{OrigErr: err, Query: []byte(query)}
×
261
                }
640✔
262
                return nil
263
        })
264
}
265

300✔
266
func (p *Postgres) Run(migration io.Reader) error {
310✔
267
        if p.config.MultiStatementEnabled {
10✔
268
                var err error
30✔
269
                if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
20✔
UNCOV
270
                        if err = p.runStatement(m); err != nil {
×
271
                                return false
×
272
                        }
20✔
UNCOV
273
                        return true
×
274
                }); e != nil {
×
275
                        return e
×
276
                }
10✔
277
                return err
278
        }
290✔
279
        migr, err := io.ReadAll(migration)
290✔
UNCOV
280
        if err != nil {
×
281
                return err
×
282
        }
290✔
283
        return p.runStatement(migr)
284
}
285

310✔
286
func (p *Postgres) runStatement(statement []byte) error {
310✔
287
        ctx := context.Background()
310✔
UNCOV
288
        if p.config.StatementTimeout != 0 {
×
289
                var cancel context.CancelFunc
×
290
                ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
×
291
                defer cancel()
×
292
        }
310✔
293
        query := string(statement)
310✔
UNCOV
294
        if strings.TrimSpace(query) == "" {
×
295
                return nil
×
296
        }
320✔
297
        if _, err := p.conn.ExecContext(ctx, query); err != nil {
20✔
298
                if pgErr, ok := err.(*pq.Error); ok {
10✔
299
                        var line uint
10✔
300
                        var col uint
10✔
301
                        var lineColOK bool
20✔
302
                        if pgErr.Position != "" {
20✔
303
                                if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
10✔
304
                                        line, col, lineColOK = computeLineFromPos(query, int(pos))
10✔
305
                                }
306
                        }
10✔
307
                        message := fmt.Sprintf("migration failed: %s", pgErr.Message)
20✔
308
                        if lineColOK {
10✔
309
                                message = fmt.Sprintf("%s (column %d)", message, col)
10✔
310
                        }
10✔
UNCOV
311
                        if pgErr.Detail != "" {
×
312
                                message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
×
313
                        }
10✔
314
                        return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
UNCOV
315
                }
×
316
                return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
317
        }
300✔
318
        return nil
319
}
320

74✔
321
// computeLineFromPos translates a 1-based character position, as reported in
// a Postgres error, into a 1-based line number and a column within s.
//
// CRLF sequences in s are collapsed to LF before counting. ok is false, and
// line and col are 0, when pos is smaller than 1 or larger than the number of
// characters in the normalized text.
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
	// replace crlf with lf
	s = strings.ReplaceAll(s, "\r\n", "\n")
	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
	runes := []rune(s)
	// A position below 1 is not a valid 1-based index; a negative one would
	// additionally make the slice expression below panic.
	if pos < 1 || pos > len(runes) {
		return 0, 0, false
	}
	sel := runes[:pos]
	line = uint(runesCount(sel, newLine) + 1)
	col = uint(pos - 1 - runesLastIndex(sel, newLine))
	return line, col, true
}

// newLine is the line separator used when locating a position.
const newLine = '\n'

// runesCount returns how many elements of input are equal to target.
func runesCount(input []rune, target rune) int {
	var count int
	for _, r := range input {
		if r == target {
			count++
		}
	}
	return count
}

// runesLastIndex returns the index of the last element of input equal to
// target, or -1 when there is none.
func runesLastIndex(input []rune, target rune) int {
	for i := len(input) - 1; i >= 0; i-- {
		if input[i] == target {
			return i
		}
	}
	return -1
}
355

380✔
356
func (p *Postgres) SetVersion(version int, dirty bool) error {
380✔
357
        tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
380✔
UNCOV
358
        if err != nil {
×
359
                return &database.Error{OrigErr: err, Err: "transaction start failed"}
×
360
        }
361

380✔
362
        query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName)
380✔
UNCOV
363
        if _, err := tx.Exec(query); err != nil {
×
364
                if errRollback := tx.Rollback(); errRollback != nil {
×
365
                        err = multierror.Append(err, errRollback)
×
366
                }
×
367
                return &database.Error{OrigErr: err, Query: []byte(query)}
368
        }
369

370
        // Also re-write the schema version for nil dirty versions to prevent
371
        // empty schema version for failed down migration on the first migration
372
        // See: https://github.com/golang-migrate/migrate/issues/330
730✔
373
        if version >= 0 || (version == database.NilVersion && dirty) {
350✔
374
                query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
350✔
UNCOV
375
                if _, err := tx.Exec(query, version, dirty); err != nil {
×
376
                        if errRollback := tx.Rollback(); errRollback != nil {
×
377
                                err = multierror.Append(err, errRollback)
×
378
                        }
×
379
                        return &database.Error{OrigErr: err, Query: []byte(query)}
380
                }
381
        }
382

380✔
UNCOV
383
        if err := tx.Commit(); err != nil {
×
384
                return &database.Error{OrigErr: err, Err: "transaction commit failed"}
×
385
        }
386

380✔
387
        return nil
388
}
389

260✔
390
func (p *Postgres) Version() (version int, dirty bool, err error) {
260✔
391
        query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
260✔
392
        err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
260✔
393
        switch {
90✔
394
        case err == sql.ErrNoRows:
90✔
395
                return database.NilVersion, false, nil
UNCOV
396

×
397
        case err != nil:
×
398
                if e, ok := err.(*pq.Error); ok {
×
399
                        if e.Code.Name() == "undefined_table" {
×
400
                                return database.NilVersion, false, nil
×
401
                        }
UNCOV
402
                }
×
403
                return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
404

170✔
405
        default:
170✔
406
                return version, dirty, nil
407
        }
408
}
409

40✔
410
func (p *Postgres) Drop() (err error) {
40✔
411
        // select all tables in current schema
40✔
412
        query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
40✔
413
        tables, err := p.conn.QueryContext(context.Background(), query)
40✔
UNCOV
414
        if err != nil {
×
415
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
416
        }
80✔
417
        defer func() {
40✔
UNCOV
418
                if errClose := tables.Close(); errClose != nil {
×
419
                        err = multierror.Append(err, errClose)
×
420
                }
421
        }()
422

423
        // delete one table after another
40✔
424
        tableNames := make([]string, 0)
110✔
425
        for tables.Next() {
70✔
426
                var tableName string
70✔
UNCOV
427
                if err := tables.Scan(&tableName); err != nil {
×
428
                        return err
×
429
                }
140✔
430
                if len(tableName) > 0 {
70✔
431
                        tableNames = append(tableNames, tableName)
70✔
432
                }
433
        }
40✔
UNCOV
434
        if err := tables.Err(); err != nil {
×
435
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
436
        }
437

80✔
438
        if len(tableNames) > 0 {
40✔
439
                // delete one by one ...
110✔
440
                for _, t := range tableNames {
70✔
441
                        query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
70✔
UNCOV
442
                        if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
×
443
                                return &database.Error{OrigErr: err, Query: []byte(query)}
×
444
                        }
445
                }
446
        }
447

40✔
448
        return nil
449
}
450

451
// ensureVersionTable checks if versions table exists and, if not, creates it.
452
// Note that this function locks the database, which deviates from the usual
453
// convention of "caller locks" in the Postgres type.
520✔
454
func (p *Postgres) ensureVersionTable() (err error) {
520✔
UNCOV
455
        if err = p.Lock(); err != nil {
×
456
                return err
×
457
        }
458

1,040✔
459
        defer func() {
520✔
UNCOV
460
                if e := p.Unlock(); e != nil {
×
461
                        if err == nil {
×
462
                                err = e
×
463
                        } else {
×
464
                                err = multierror.Append(err, e)
×
465
                        }
466
                }
467
        }()
468

469
        // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
470
        // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
471
        // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
472
        // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
520✔
473
        query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
520✔
474
        row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
520✔
475

520✔
476
        var count int
520✔
477
        err = row.Scan(&count)
520✔
UNCOV
478
        if err != nil {
×
479
                return &database.Error{OrigErr: err, Query: []byte(query)}
×
480
        }
481

820✔
482
        if count == 1 {
300✔
483
                return nil
300✔
484
        }
485

220✔
486
        query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
240✔
487
        if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
20✔
488
                return &database.Error{OrigErr: err, Query: []byte(query)}
20✔
489
        }
490

200✔
491
        return nil
492
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc