• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

golang-migrate / migrate / 13163293181

05 Feb 2025 05:37PM UTC coverage: 56.499% (+0.2%) from 56.319%
13163293181

Pull #1197

github

jtwatson
Update memefish
Pull Request #1197: feature(spanner): Implement DML Support for Spanner

64 of 92 new or added lines in 1 file covered. (69.57%)

1 existing line in 1 file now uncovered.

4612 of 8163 relevant lines covered (56.5%)

52.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.73
/database/spanner/spanner.go
1
package spanner
2

3
import (
4
        "context"
5
        "errors"
6
        "fmt"
7
        "io"
8
        "log"
9
        nurl "net/url"
10
        "regexp"
11
        "strconv"
12
        "strings"
13

14
        "cloud.google.com/go/spanner"
15
        sdb "cloud.google.com/go/spanner/admin/database/apiv1"
16

17
        "github.com/cloudspannerecosystem/memefish"
18
        "github.com/cloudspannerecosystem/memefish/token"
19
        "github.com/golang-migrate/migrate/v4"
20
        "github.com/golang-migrate/migrate/v4/database"
21

22
        adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
23
        "github.com/hashicorp/go-multierror"
24
        uatomic "go.uber.org/atomic"
25
        "google.golang.org/api/iterator"
26
)
27

28
func init() {
2✔
29
        db := Spanner{}
2✔
30
        database.Register("spanner", &db)
2✔
31
}
2✔
32

33
// DefaultMigrationsTable is the migrations table name that WithInstance
// falls back to when Config.MigrationsTable is empty.
const DefaultMigrationsTable = "SchemaMigrations"
35

36
// Values stored in Spanner.lock to record the state of the driver's lock.
const (
	unlockedVal = iota // 0: the lock is free
	lockedVal          // 1: the lock is held
)
40

41
// Sentinel errors exposed by the Spanner driver.
//
// ErrNilConfig and ErrNoDatabaseName are returned by WithInstance for a nil
// Config and an empty Config.DatabaseName respectively. ErrLockHeld is
// returned by Lock when the lock is already taken, and ErrLockNotHeld by
// Unlock when it is not. ErrNoSchema and ErrDatabaseDirty are exported but
// are not returned by any code in this file.
var (
	ErrNilConfig      = errors.New("no config")
	ErrNoDatabaseName = errors.New("no database name")
	ErrNoSchema       = errors.New("no schema")
	ErrDatabaseDirty  = errors.New("database is dirty")
	ErrLockHeld       = errors.New("unable to obtain lock")
	ErrLockNotHeld    = errors.New("unable to release already released lock")
)
50

51
// Config holds the settings used to build a Spanner driver.
type Config struct {
	// MigrationsTable names the table that stores the migration version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName identifies the target database. WithInstance rejects an
	// empty value with ErrNoDatabaseName.
	DatabaseName string
	// CleanStatements selects how Run submits a migration. When false, the
	// whole migration text is sent as a single DDL statement. When true, the
	// migration is first split into individual statements with the memefish
	// lexer, and consecutive DDL and DML statements are executed as separate
	// groups.
	CleanStatements bool
}
61

62
// Spanner implements database.Driver for Google Cloud Spanner
63
type Spanner struct {
64
        db     *DB
65
        config *Config
66
        lock   *uatomic.Uint32
67
}
68

69
type DB struct {
70
        admin  *sdb.DatabaseAdminClient
71
        data   *spanner.Client
72
        shared bool
73
}
74

75
func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
×
76
        return &DB{
×
NEW
77
                admin:  &admin,
×
NEW
78
                data:   &data,
×
NEW
79
                shared: true,
×
80
        }
×
81
}
×
82

83
// WithInstance implements database.Driver
84
func WithInstance(instance *DB, config *Config) (database.Driver, error) {
8✔
85
        if config == nil {
8✔
86
                return nil, ErrNilConfig
×
87
        }
×
88

89
        if len(config.DatabaseName) == 0 {
8✔
90
                return nil, ErrNoDatabaseName
×
91
        }
×
92

93
        if len(config.MigrationsTable) == 0 {
16✔
94
                config.MigrationsTable = DefaultMigrationsTable
8✔
95
        }
8✔
96

97
        sx := &Spanner{
8✔
98
                db:     instance,
8✔
99
                config: config,
8✔
100
                lock:   uatomic.NewUint32(unlockedVal),
8✔
101
        }
8✔
102

8✔
103
        if err := sx.ensureVersionTable(); err != nil {
8✔
104
                return nil, err
×
105
        }
×
106

107
        return sx, nil
8✔
108
}
109

110
// Open implements database.Driver
111
func (s *Spanner) Open(url string) (database.Driver, error) {
8✔
112
        purl, err := nurl.Parse(url)
8✔
113
        if err != nil {
8✔
114
                return nil, err
×
115
        }
×
116

117
        ctx := context.Background()
8✔
118

8✔
119
        adminClient, err := sdb.NewDatabaseAdminClient(ctx)
8✔
120
        if err != nil {
8✔
121
                return nil, err
×
122
        }
×
123
        dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
8✔
124
        dataClient, err := spanner.NewClient(ctx, dbname)
8✔
125
        if err != nil {
8✔
126
                log.Fatal(err)
×
127
        }
×
128

129
        migrationsTable := purl.Query().Get("x-migrations-table")
8✔
130

8✔
131
        cleanQuery := purl.Query().Get("x-clean-statements")
8✔
132
        clean := false
8✔
133
        if cleanQuery != "" {
12✔
134
                clean, err = strconv.ParseBool(cleanQuery)
4✔
135
                if err != nil {
4✔
136
                        return nil, err
×
137
                }
×
138
        }
139

140
        db := &DB{admin: adminClient, data: dataClient}
8✔
141
        return WithInstance(db, &Config{
8✔
142
                DatabaseName:    dbname,
8✔
143
                MigrationsTable: migrationsTable,
8✔
144
                CleanStatements: clean,
8✔
145
        })
8✔
146
}
147

148
// Close implements database.Driver
149
func (s *Spanner) Close() error {
×
NEW
150
        if s.db.shared {
×
NEW
151
                return nil
×
NEW
152
        }
×
153
        s.db.data.Close()
×
154
        return s.db.admin.Close()
×
155
}
156

157
// Lock implements database.Driver but doesn't do anything because Spanner only
158
// enqueues the UpdateDatabaseDdlRequest.
159
func (s *Spanner) Lock() error {
28✔
160
        if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped {
52✔
161
                return nil
24✔
162
        }
24✔
163
        return ErrLockHeld
4✔
164
}
165

166
// Unlock implements database.Driver but no action required, see Lock.
167
func (s *Spanner) Unlock() error {
24✔
168
        if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped {
48✔
169
                return nil
24✔
170
        }
24✔
171
        return ErrLockNotHeld
×
172
}
173

174
// Run implements database.Driver
175
func (s *Spanner) Run(migration io.Reader) error {
18✔
176
        migr, err := io.ReadAll(migration)
18✔
177
        if err != nil {
18✔
178
                return err
×
179
        }
×
180

181
        ctx := context.Background()
18✔
182

18✔
183
        if !s.config.CleanStatements {
30✔
184
                return s.runDdl(ctx, []string{string(migr)})
12✔
185
        }
12✔
186

187
        stmtGroups, err := statementGroups(migr)
6✔
188
        if err != nil {
6✔
NEW
189
                return err
×
NEW
190
        }
×
191

192
        for _, group := range stmtGroups {
16✔
193
                switch group.typ {
10✔
194
                case statementTypeDDL:
6✔
195
                        if err := s.runDdl(ctx, group.stmts); err != nil {
6✔
NEW
196
                                return err
×
NEW
197
                        }
×
198
                case statementTypeDML:
4✔
199
                        if err := s.runDml(ctx, group.stmts); err != nil {
4✔
NEW
200
                                return err
×
NEW
201
                        }
×
NEW
202
                default:
×
NEW
203
                        return fmt.Errorf("unknown statement type: %s", group.typ)
×
204
                }
205
        }
206

207
        return nil
6✔
208
}
209

210
func (s *Spanner) runDdl(ctx context.Context, stmts []string) error {
18✔
211
        op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
18✔
212
                Database:   s.config.DatabaseName,
18✔
213
                Statements: stmts,
18✔
214
        })
18✔
215

18✔
216
        if err != nil {
18✔
NEW
217
                return &database.Error{OrigErr: err, Err: "migration failed", Query: []byte(strings.Join(stmts, ";\n"))}
×
218
        }
×
219

220
        if err := op.Wait(ctx); err != nil {
18✔
NEW
221
                return &database.Error{OrigErr: err, Err: "migration failed", Query: []byte(strings.Join(stmts, ";\n"))}
×
NEW
222
        }
×
223

224
        return nil
18✔
225
}
226

227
func (s *Spanner) runDml(ctx context.Context, stmts []string) error {
4✔
228
        _, err := s.db.data.ReadWriteTransaction(ctx,
4✔
229
                func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
8✔
230
                        for _, s := range stmts {
8✔
231
                                _, err := txn.Update(ctx, spanner.Statement{SQL: s})
4✔
232
                                if err != nil {
4✔
NEW
233
                                        return err
×
NEW
234
                                }
×
235
                        }
236
                        return nil
4✔
237
                })
238
        if err != nil {
4✔
NEW
239
                return &database.Error{OrigErr: err, Err: "migration failed", Query: []byte(strings.Join(stmts, ";\n"))}
×
UNCOV
240
        }
×
241

242
        return nil
4✔
243
}
244

245
// SetVersion implements database.Driver
246
func (s *Spanner) SetVersion(version int, dirty bool) error {
52✔
247
        ctx := context.Background()
52✔
248

52✔
249
        _, err := s.db.data.ReadWriteTransaction(ctx,
52✔
250
                func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
104✔
251
                        m := []*spanner.Mutation{
52✔
252
                                spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
52✔
253
                                spanner.Insert(s.config.MigrationsTable,
52✔
254
                                        []string{"Version", "Dirty"},
52✔
255
                                        []interface{}{version, dirty},
52✔
256
                                )}
52✔
257
                        return txn.BufferWrite(m)
52✔
258
                })
52✔
259
        if err != nil {
52✔
260
                return &database.Error{OrigErr: err}
×
261
        }
×
262

263
        return nil
52✔
264
}
265

266
// Version implements database.Driver
267
func (s *Spanner) Version() (version int, dirty bool, err error) {
32✔
268
        ctx := context.Background()
32✔
269

32✔
270
        stmt := spanner.Statement{
32✔
271
                SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
32✔
272
        }
32✔
273
        iter := s.db.data.Single().Query(ctx, stmt)
32✔
274
        defer iter.Stop()
32✔
275

32✔
276
        row, err := iter.Next()
32✔
277
        switch err {
32✔
278
        case iterator.Done:
8✔
279
                return database.NilVersion, false, nil
8✔
280
        case nil:
24✔
281
                var v int64
24✔
282
                if err = row.Columns(&v, &dirty); err != nil {
24✔
283
                        return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
×
284
                }
×
285
                version = int(v)
24✔
286
        default:
×
287
                return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
×
288
        }
289

290
        return version, dirty, nil
24✔
291
}
292

293
// nameMatcher extracts the object name from a schema statement: submatch 2 is
// the table name of a CREATE TABLE statement and submatch 4 is the index name
// of a CREATE ... INDEX statement.
var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
294

295
// Drop implements database.Driver. Retrieves the database schema first and
296
// creates statements to drop the indexes and tables accordingly.
297
// Note: The drop statements are created in reverse order to how they're
298
// provided in the schema. Assuming the schema describes how the database can
299
// be "build up", it seems logical to "unbuild" the database simply by going the
300
// opposite direction. More testing
301
func (s *Spanner) Drop() error {
8✔
302
        ctx := context.Background()
8✔
303
        res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
8✔
304
                Database: s.config.DatabaseName,
8✔
305
        })
8✔
306
        if err != nil {
8✔
307
                return &database.Error{OrigErr: err, Err: "drop failed"}
×
308
        }
×
309
        if len(res.Statements) == 0 {
8✔
310
                return nil
×
311
        }
×
312

313
        stmts := make([]string, 0)
8✔
314
        for i := len(res.Statements) - 1; i >= 0; i-- {
28✔
315
                s := res.Statements[i]
20✔
316
                m := nameMatcher.FindSubmatch([]byte(s))
20✔
317

20✔
318
                if len(m) == 0 {
20✔
319
                        continue
×
320
                } else if tbl := m[2]; len(tbl) > 0 {
40✔
321
                        stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
20✔
322
                } else if idx := m[4]; len(idx) > 0 {
20✔
323
                        stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
×
324
                }
×
325
        }
326

327
        op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
8✔
328
                Database:   s.config.DatabaseName,
8✔
329
                Statements: stmts,
8✔
330
        })
8✔
331
        if err != nil {
8✔
332
                return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
×
333
        }
×
334
        if err := op.Wait(ctx); err != nil {
8✔
335
                return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
×
336
        }
×
337

338
        return nil
8✔
339
}
340

341
// ensureVersionTable checks if versions table exists and, if not, creates it.
342
// Note that this function locks the database, which deviates from the usual
343
// convention of "caller locks" in the Spanner type.
344
func (s *Spanner) ensureVersionTable() (err error) {
8✔
345
        if err = s.Lock(); err != nil {
8✔
346
                return err
×
347
        }
×
348

349
        defer func() {
16✔
350
                if e := s.Unlock(); e != nil {
8✔
351
                        if err == nil {
×
352
                                err = e
×
353
                        } else {
×
354
                                err = multierror.Append(err, e)
×
355
                        }
×
356
                }
357
        }()
358

359
        ctx := context.Background()
8✔
360
        tbl := s.config.MigrationsTable
8✔
361
        iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
8✔
362
        if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
8✔
363
                return nil
×
364
        }
×
365

366
        stmt := fmt.Sprintf(`CREATE TABLE %s (
8✔
367
    Version INT64 NOT NULL,
8✔
368
    Dirty    BOOL NOT NULL
8✔
369
        ) PRIMARY KEY(Version)`, tbl)
8✔
370

8✔
371
        op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
8✔
372
                Database:   s.config.DatabaseName,
8✔
373
                Statements: []string{stmt},
8✔
374
        })
8✔
375

8✔
376
        if err != nil {
8✔
377
                return &database.Error{OrigErr: err, Query: []byte(stmt)}
×
378
        }
×
379
        if err := op.Wait(ctx); err != nil {
8✔
380
                return &database.Error{OrigErr: err, Query: []byte(stmt)}
×
381
        }
×
382

383
        return nil
8✔
384
}
385

386
// statementType classifies a migration statement by how Run executes it.
type statementType string

const (
	// statementTypeUnknown is the zero value: no type has been assigned yet.
	statementTypeUnknown statementType = ""
	// statementTypeDDL marks statements that Run passes to runDdl.
	statementTypeDDL statementType = "DDL"
	// statementTypeDML marks statements that Run passes to runDml.
	statementTypeDML statementType = "DML"
)

// statementGroup is a run of consecutive statements sharing one type.
type statementGroup struct {
	// typ is the type common to every statement in the group.
	typ statementType
	// stmts holds the statement texts in their original order.
	stmts []string
}
398

399
func statementGroups(migr []byte) (groups []*statementGroup, err error) {
40✔
400
        lex := &memefish.Lexer{
40✔
401
                File: &token.File{Buffer: string(migr)},
40✔
402
        }
40✔
403

40✔
404
        group := &statementGroup{}
40✔
405
        var stmtTyp statementType
40✔
406
        var stmt strings.Builder
40✔
407
        for {
1,016✔
408
                if err := lex.NextToken(); err != nil {
976✔
NEW
409
                        return nil, err
×
NEW
410
                }
×
411

412
                if stmtTyp == statementTypeUnknown {
1,064✔
413
                        switch {
88✔
414
                        case lex.Token.IsKeywordLike("INSERT") || lex.Token.IsKeywordLike("DELETE") || lex.Token.IsKeywordLike("UPDATE"):
8✔
415
                                stmtTyp = statementTypeDML
8✔
416
                        default:
80✔
417
                                stmtTyp = statementTypeDDL
80✔
418
                        }
419
                        if group.typ != stmtTyp {
142✔
420
                                if len(group.stmts) > 0 {
68✔
421
                                        groups = append(groups, group)
14✔
422
                                }
14✔
423
                                group = &statementGroup{typ: stmtTyp}
54✔
424
                        }
425
                }
426

427
                if lex.Token.Kind == token.TokenEOF || lex.Token.Kind == ";" {
1,064✔
428
                        if stmt.Len() > 0 {
148✔
429
                                group.stmts = append(group.stmts, stmt.String())
60✔
430
                        }
60✔
431
                        stmtTyp = statementTypeUnknown
88✔
432
                        stmt.Reset()
88✔
433

88✔
434
                        if lex.Token.Kind == token.TokenEOF {
128✔
435
                                if len(group.stmts) > 0 {
70✔
436
                                        groups = append(groups, group)
30✔
437
                                }
30✔
438

439
                                break
40✔
440
                        }
441

442
                        continue
48✔
443
                }
444

445
                if len(lex.Token.Comments) > 0 && strings.HasPrefix(lex.Token.Comments[0].Raw, "--") {
896✔
446
                        // standard comment Token consumes a \n, so we need to add it back
8✔
447
                        if _, err := stmt.WriteString("\n"); err != nil {
8✔
NEW
448
                                return nil, err
×
NEW
449
                        }
×
450
                }
451
                if stmt.Len() > 0 {
1,720✔
452
                        if _, err := stmt.WriteString(lex.Token.Space); err != nil {
832✔
NEW
453
                                return nil, err
×
NEW
454
                        }
×
455
                }
456
                if _, err := stmt.WriteString(lex.Token.Raw); err != nil {
888✔
NEW
457
                        return nil, err
×
NEW
458
                }
×
459
        }
460

461
        return groups, nil
40✔
462
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc