Skip to content

Commit

Permalink
Implement GetLatestVersion for all natively supported dialects (#758)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Apr 26, 2024
1 parent 2d33f01 commit 272603b
Show file tree
Hide file tree
Showing 15 changed files with 135 additions and 57 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

- Add `CheckPending` method to `goose.Provider` to check if there are pending migrations, returns
the current (max db) version and the latest (max file) version. (#756)
- Clarify `GetLatestVersion` method MUST return `ErrVersionNotFound` if no latest migration is
found. Previously it was returning a -1 and nil error, which was inconsistent with the rest of the
API surface.

- Add `GetLatestVersion` implementations to all existing dialects. This is an optimization to avoid
loading all migrations when only the latest version is needed. This uses the `max` function in SQL
to get the latest version_id irrespective of the order of applied migrations.
- Refactor existing portions of the code to use the new `GetLatestVersion` method.

## [v3.20.0]

Expand Down
10 changes: 9 additions & 1 deletion database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,15 @@ func (s *store) GetMigration(
}

func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
return -1, errors.New("not implemented")
q := s.querier.GetLatestVersion(s.tablename)
var version sql.NullInt64
if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
return -1, fmt.Errorf("failed to get latest version: %w", err)
}
if !version.Valid {
return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
}
return version.Int64, nil
}

func (s *store) ListMigrations(
Expand Down
8 changes: 6 additions & 2 deletions database/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ import (
)

var (
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
// ErrVersionNotFound must be returned by [GetMigration] or [GetLatestVersion] when a migration
// does not exist.
ErrVersionNotFound = errors.New("version not found")

// ErrNotImplemented must be returned by methods that are not implemented.
ErrNotImplemented = errors.New("not implemented")
)

// Store is an interface that defines methods for tracking and managing migrations. It is used by
Expand All @@ -34,7 +38,7 @@ type Store interface {
// version is not found, this method must return [ErrVersionNotFound].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
// GetLatestVersion retrieves the last applied migration version. If no migrations exist, this
// method must return -1 and no error.
// method must return [ErrVersionNotFound].
GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error)
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error. Typically this method will return
Expand Down
33 changes: 29 additions & 4 deletions database/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ func testStore(
if alreadyExists != nil {
alreadyExists(t, err)
}
// Get the latest version. There should be none.
_, err = store.GetLatestVersion(ctx, db)
check.IsError(t, err, database.ErrVersionNotFound)

// List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error {
Expand All @@ -108,7 +111,12 @@ func testStore(
// Insert 5 migrations in addition to the zero migration.
for i := 0; i < 6; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, int64(i))
return nil
})
check.NoError(t, err)
}
Expand All @@ -129,7 +137,12 @@ func testStore(
// Delete 3 migrations backwards
for i := 5; i >= 3; i-- {
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.Delete(ctx, conn, int64(i))
err := store.Delete(ctx, conn, int64(i))
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, int64(i-1))
return nil
})
check.NoError(t, err)
}
Expand Down Expand Up @@ -163,17 +176,29 @@ func testStore(

// 1. *sql.Tx
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.Delete(ctx, tx, 2)
err := store.Delete(ctx, tx, 2)
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, tx)
check.NoError(t, err)
check.Number(t, latest, 1)
return nil
})
check.NoError(t, err)
// 2. *sql.Conn
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.Delete(ctx, conn, 1)
err := store.Delete(ctx, conn, 1)
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, 0)
return nil
})
check.NoError(t, err)
// 3. *sql.DB
err = store.Delete(ctx, db, 0)
check.NoError(t, err)
_, err = store.GetLatestVersion(ctx, db)
check.IsError(t, err, database.ErrVersionNotFound)

// List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error {
Expand Down
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ func (c *Clickhouse) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC`
return fmt.Sprintf(q, tableName)
}

func (c *Clickhouse) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
19 changes: 9 additions & 10 deletions internal/dialect/dialectquery/dialectquery.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
package dialectquery

// Querier is the interface that wraps the basic methods to create a dialect
// specific query.
// Querier is the interface that wraps the basic methods to create a dialect specific query.
type Querier interface {
// CreateTable returns the SQL query string to create the db version table.
CreateTable(tableName string) string

// InsertVersion returns the SQL query string to insert a new version into
// the db version table.
// InsertVersion returns the SQL query string to insert a new version into the db version table.
InsertVersion(tableName string) string

// DeleteVersion returns the SQL query string to delete a version from
// the db version table.
// DeleteVersion returns the SQL query string to delete a version from the db version table.
DeleteVersion(tableName string) string

// GetMigrationByVersion returns the SQL query string to get a single
// migration by version.
// GetMigrationByVersion returns the SQL query string to get a single migration by version.
//
// The query should return the timestamp and is_applied columns.
GetMigrationByVersion(tableName string) string

// ListMigrations returns the SQL query string to list all migrations in
// descending order by id.
// ListMigrations returns the SQL query string to list all migrations in descending order by id.
//
// The query should return the version_id and is_applied columns.
ListMigrations(tableName string) string

// GetLatestVersion returns the SQL query string to get the last version_id from the db version
// table. Returns a nullable int64 value.
GetLatestVersion(tableName string) string
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ func (m *Mysql) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (m *Mysql) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ func (p *Postgres) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ func (r *Redshift) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (r *Redshift) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ func (s *Sqlite3) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (s *Sqlite3) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ func (s *Sqlserver) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (s *Sqlserver) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ func (t *Tidb) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (t *Tidb) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/vertica.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ func (v *Vertica) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName)
}

func (v *Vertica) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
5 changes: 5 additions & 0 deletions internal/dialect/dialectquery/ydb.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ func (c *Ydb) ListMigrations(tableName string) string {
FROM %s ORDER BY __discard_column_tstamp DESC`
return fmt.Sprintf(q, tableName)
}

func (c *Ydb) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}
69 changes: 29 additions & 40 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"io/fs"
"math"
"sort"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -167,6 +166,10 @@ func (p *Provider) HasPending(ctx context.Context) (bool, error) {
//
// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
// for pending migrations without blocking or being blocked by other operations.
//
// If out-of-order migrations are enabled this method is not suitable for checking pending
// migrations because it ONLY returns the highest version in the database. Instead, use the
// [HasPending] method.
func (p *Provider) CheckPending(ctx context.Context) (current, target int64, err error) {
return p.checkPending(ctx)
}
Expand Down Expand Up @@ -483,30 +486,22 @@ func (p *Provider) checkPending(ctx context.Context) (current, target int64, ret
retErr = multierr.Append(retErr, cleanup())
}()

target = p.migrations[len(p.migrations)-1].Version

// If versioning is disabled, we always have pending migrations and the target version is the
// last migration.
if p.cfg.disableVersioning {
return -1, p.migrations[len(p.migrations)-1].Version, nil
return -1, target, nil
}
// optimize(mf): we should only fetch the max version from the database, no need to fetch all
// migrations only to get the max version when we're not using out-of-order migrations.
res, err := p.store.ListMigrations(ctx, conn)

current, err = p.store.GetLatestVersion(ctx, conn)
if err != nil {
return -1, -1, err
}
dbVersions := make([]int64, 0, len(res))
for _, m := range res {
dbVersions = append(dbVersions, m.Version)
}
sort.Slice(dbVersions, func(i, j int) bool {
return dbVersions[i] < dbVersions[j]
})
if len(dbVersions) == 0 {
return -1, -1, errMissingZeroVersion
} else {
current = dbVersions[len(dbVersions)-1]
if errors.Is(err, database.ErrVersionNotFound) {
return -1, target, errMissingZeroVersion
}
return -1, target, err
}
return current, p.migrations[len(p.migrations)-1].Version, nil
return current, target, nil
}

func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
Expand All @@ -523,7 +518,8 @@ func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
return true, nil
}
if p.cfg.allowMissing {
// List all migrations from the database.
// List all migrations from the database. We cannot optimize this because we need to check
// that EVERY migration known the provider has been applied.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return false, err
Expand All @@ -544,16 +540,16 @@ func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
}
return false, nil
}
// If out-of-order migrations are not allowed, we can optimize this by only checking whether the
// last migration the provider knows about is applied.
last := p.migrations[len(p.migrations)-1]
if _, err := p.store.GetMigration(ctx, conn, last.Version); err != nil {
// If out-of-order migrations are not allowed, we can optimize this by only checking the latest
// version in the database against the latest migration version.
current, err := p.store.GetLatestVersion(ctx, conn)
if err != nil {
if errors.Is(err, database.ErrVersionNotFound) {
return true, nil
return false, errMissingZeroVersion
}
return false, err
}
return false, nil
return current < p.migrations[len(p.migrations)-1].Version, nil
}

func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
Expand Down Expand Up @@ -591,9 +587,6 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err

// getDBMaxVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. conn may be nil, in which case a connection is initialized.
//
// optimize(mf): we should only fetch the max version from the database, no need to fetch all
// migrations only to get the max version. This means expanding the Store interface.
func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) {
if conn == nil {
var cleanup func() error
Expand All @@ -606,17 +599,13 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64
retErr = multierr.Append(retErr, cleanup())
}()
}
res, err := p.store.ListMigrations(ctx, conn)

latest, err := p.store.GetLatestVersion(ctx, conn)
if err != nil {
return 0, err
if errors.Is(err, database.ErrVersionNotFound) {
return 0, errMissingZeroVersion
}
return -1, err
}
if len(res) == 0 {
return 0, errMissingZeroVersion
}
// Sort in descending order.
sort.Slice(res, func(i, j int) bool {
return res[i].Version > res[j].Version
})
// Return the highest version.
return res[0].Version, nil
return latest, nil
}

0 comments on commit 272603b

Please sign in to comment.