Skip to content

Commit

Permalink
feat: Expose setter for global Go migration registry (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Oct 28, 2023
1 parent d59dd9f commit 20a99fa
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 18 deletions.
87 changes: 87 additions & 0 deletions globals.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package goose

import (
"errors"
"fmt"
)

var (
registeredGoMigrations = make(map[int64]*Migration)
)

// ResetGlobalMigrations resets the global go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalMigrations() {
registeredGoMigrations = make(map[int64]*Migration)
}

// SetGlobalMigrations registers go migrations globally. It returns an error if a migration with the
// same version has already been registered.
//
// Source may be empty, but if it is set, it must be a path with a numeric component that matches
// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, m := range migrations {
// make a copy of the migration so we can modify it without affecting the original.
if err := validGoMigration(&m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user.
registeredGoMigrations[m.Version] = &m
}
return nil
}

func validGoMigration(m *Migration) error {
if m == nil {
return errors.New("must not be nil")
}
if !m.Registered {
return errors.New("must be registered")
}
if m.Type != TypeGo {
return fmt.Errorf("type must be %q", TypeGo)
}
if m.Version < 1 {
return errors.New("version must be greater than zero")
}
if m.Source != "" {
// If the source is set, expect it to be a path with a numeric component that matches the
// version. This field is not intended to be used for descriptive purposes.
version, err := NumericComponent(m.Source)
if err != nil {
return err
}
if version != m.Version {
return fmt.Errorf("numeric component [%d] in go migration does not match version in source %q", m.Version, m.Source)
}
}
// It's valid for all of these funcs to be nil. Which means version the go migration but do not
// run anything.
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
// Do not allow legacy functions to be set.
if m.UpFn != nil {
return errors.New("must not specify UpFn")
}
if m.DownFn != nil {
return errors.New("must not specify DownFn")
}
if m.UpFnNoTx != nil {
return errors.New("must not specify UpFnNoTx")
}
if m.DownFnNoTx != nil {
return errors.New("must not specify DownFnNoTx")
}
return nil
}
113 changes: 113 additions & 0 deletions globals_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package goose_test

import (
"context"
"database/sql"
"testing"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
)

func TestGlobalRegister(t *testing.T) {
// Avoid polluting other tests and do not run in parallel.
t.Cleanup(func() {
goose.ResetGlobalMigrations()
})
fnNoTx := func(context.Context, *sql.DB) error { return nil }
fn := func(context.Context, *sql.Tx) error { return nil }

// Success.
err := goose.SetGlobalMigrations(
[]goose.Migration{}...,
)
check.NoError(t, err)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo, UpFnContext: fn},
)
check.NoError(t, err)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = goose.SetGlobalMigrations(
goose.Migration{
Registered: true,
Version: 2,
Source: "00002_foo.sql",
Type: goose.TypeGo,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
},
)
check.NoError(t, err)
// Reset.
{
goose.ResetGlobalMigrations()
}
// Failure.
err = goose.SetGlobalMigrations(
goose.Migration{},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must be registered")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeSQL},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 0, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: numeric component [1] in go migration does not match version in source "2_foo.sql"`)
// Legacy functions.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx")
// Context-aware functions.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFnContext: fn, UpFnNoTxContext: fnNoTx, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFnContext: fn, DownFnNoTxContext: fnNoTx, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext")
// Source and version mismatch.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`)
}
2 changes: 0 additions & 2 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ var (
ErrNoNextVersion = errors.New("no next version found")
// MaxVersion is the maximum allowed version.
MaxVersion int64 = math.MaxInt64

registeredGoMigrations = map[int64]*Migration{}
)

// Migrations slice.
Expand Down
34 changes: 18 additions & 16 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,27 @@ type MigrationRecord struct {
IsApplied bool // was this a result of up() or down()
}

// Migration struct.
// Migration struct represents either a SQL or Go migration.
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file
Registered bool
UseTx bool

// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
Type MigrationType
Version int64
Source string // path to .sql script or .go file
Registered bool
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext

// These fields will be removed in a future major version. They are here for backwards
// compatibility and are an implementation detail.
UseTx bool
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none

// We still save the non-context versions in the struct in case someone is using them. Goose
// does not use these internally anymore in favor of the context-aware versions.
UpFn, DownFn GoMigration
UpFnNoTx, DownFnNoTx GoMigrationNoTx

// New functions with context
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
noVersioning bool
noVersioning bool
}

func (m *Migration) String() string {
Expand Down Expand Up @@ -233,7 +235,7 @@ func NumericComponent(filename string) (int64, error) {
}
n, err := strconv.ParseInt(base[:idx], 10, 64)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", base, err)
}
if n < 1 {
return 0, errors.New("migration version must be greater than zero")
Expand Down
17 changes: 17 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package goose

// MigrationType is the type of migration.
type MigrationType string

const (
TypeGo MigrationType = "go"
TypeSQL MigrationType = "sql"
)

func (t MigrationType) String() string {
// This should never happen.
if t == "" {
return "unknown migration type"
}
return string(t)
}

0 comments on commit 20a99fa

Please sign in to comment.