diff --git a/CHANGELOG.md b/CHANGELOG.md index c65501071..445f3dbf6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Add `CheckPending` method to `goose.Provider` to check if there are pending migrations, returns the current (max db) version and the latest (max file) version. (#756) +- Clarify `GetLatestVersion` method MUST return `ErrVersionNotFound` if no latest migration is + found. Previously it was returning a -1 and nil error, which was inconsistent with the rest of the + API surface. + +- Add `GetLatestVersion` implementations to all existing dialects. This is an optimization to avoid + loading all migrations when only the latest version is needed. This uses the `max` function in SQL + to get the latest version_id irrespective of the order of applied migrations. + - Refactor existing portions of the code to use the new `GetLatestVersion` method. ## [v3.20.0] diff --git a/database/dialect.go b/database/dialect.go index ca7d24cf2..2ac197d10 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -110,7 +110,15 @@ func (s *store) GetMigration( } func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) { - return -1, errors.New("not implemented") + q := s.querier.GetLatestVersion(s.tablename) + var version sql.NullInt64 + if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil { + return -1, fmt.Errorf("failed to get latest version: %w", err) + } + if !version.Valid { + return -1, fmt.Errorf("latest %w", ErrVersionNotFound) + } + return version.Int64, nil } func (s *store) ListMigrations( diff --git a/database/store.go b/database/store.go index 60ce56ce8..0c7e44de8 100644 --- a/database/store.go +++ b/database/store.go @@ -7,8 +7,12 @@ import ( ) var ( - // ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found. + // ErrVersionNotFound must be returned by [GetMigration] or [GetLatestVersion] when a migration + // does not exist. ErrVersionNotFound = errors.New("version not found") + + // ErrNotImplemented must be returned by methods that are not implemented. + ErrNotImplemented = errors.New("not implemented") ) // Store is an interface that defines methods for tracking and managing migrations. It is used by @@ -34,7 +38,7 @@ type Store interface { // version is not found, this method must return [ErrVersionNotFound]. GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) // GetLatestVersion retrieves the last applied migration version. If no migrations exist, this - // method must return -1 and no error. + // method must return [ErrVersionNotFound]. GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) // ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If // there are no migrations, return empty slice with no error. Typically this method will return diff --git a/database/store_test.go b/database/store_test.go index 93143f619..d63f7f818 100644 --- a/database/store_test.go +++ b/database/store_test.go @@ -95,6 +95,9 @@ func testStore( if alreadyExists != nil { alreadyExists(t, err) } + // Get the latest version. There should be none. + _, err = store.GetLatestVersion(ctx, db) + check.IsError(t, err, database.ErrVersionNotFound) // List migrations. There should be none. err = runConn(ctx, db, func(conn *sql.Conn) error { @@ -108,7 +111,12 @@ func testStore( // Insert 5 migrations in addition to the zero migration. for i := 0; i < 6; i++ { err = runConn(ctx, db, func(conn *sql.Conn) error { - return store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)}) + err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)}) + check.NoError(t, err) + latest, err := store.GetLatestVersion(ctx, conn) + check.NoError(t, err) + check.Number(t, latest, int64(i)) + return nil }) check.NoError(t, err) } @@ -129,7 +137,12 @@ func testStore( // Delete 3 migrations backwards for i := 5; i >= 3; i-- { err = runConn(ctx, db, func(conn *sql.Conn) error { - return store.Delete(ctx, conn, int64(i)) + err := store.Delete(ctx, conn, int64(i)) + check.NoError(t, err) + latest, err := store.GetLatestVersion(ctx, conn) + check.NoError(t, err) + check.Number(t, latest, int64(i-1)) + return nil }) check.NoError(t, err) } @@ -163,17 +176,29 @@ func testStore( // 1. *sql.Tx err = runTx(ctx, db, func(tx *sql.Tx) error { - return store.Delete(ctx, tx, 2) + err := store.Delete(ctx, tx, 2) + check.NoError(t, err) + latest, err := store.GetLatestVersion(ctx, tx) + check.NoError(t, err) + check.Number(t, latest, 1) + return nil }) check.NoError(t, err) // 2. *sql.Conn err = runConn(ctx, db, func(conn *sql.Conn) error { - return store.Delete(ctx, conn, 1) + err := store.Delete(ctx, conn, 1) + check.NoError(t, err) + latest, err := store.GetLatestVersion(ctx, conn) + check.NoError(t, err) + check.Number(t, latest, 0) + return nil }) check.NoError(t, err) // 3. *sql.DB err = store.Delete(ctx, db, 0) check.NoError(t, err) + _, err = store.GetLatestVersion(ctx, db) + check.IsError(t, err, database.ErrVersionNotFound) // List migrations. There should be none. err = runConn(ctx, db, func(conn *sql.Conn) error { diff --git a/internal/dialect/dialectquery/clickhouse.go b/internal/dialect/dialectquery/clickhouse.go index ca07f8684..723efd4cc 100644 --- a/internal/dialect/dialectquery/clickhouse.go +++ b/internal/dialect/dialectquery/clickhouse.go @@ -37,3 +37,8 @@ func (c *Clickhouse) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC` return fmt.Sprintf(q, tableName) } + +func (c *Clickhouse) GetLatestVersion(tableName string) string { + q := `SELECT max(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index 482771aa1..5e10e46e4 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -1,28 +1,27 @@ package dialectquery -// Querier is the interface that wraps the basic methods to create a dialect -// specific query. +// Querier is the interface that wraps the basic methods to create a dialect specific query. type Querier interface { // CreateTable returns the SQL query string to create the db version table. CreateTable(tableName string) string - // InsertVersion returns the SQL query string to insert a new version into - // the db version table. + // InsertVersion returns the SQL query string to insert a new version into the db version table. InsertVersion(tableName string) string - // DeleteVersion returns the SQL query string to delete a version from - // the db version table. + // DeleteVersion returns the SQL query string to delete a version from the db version table. DeleteVersion(tableName string) string - // GetMigrationByVersion returns the SQL query string to get a single - // migration by version. + // GetMigrationByVersion returns the SQL query string to get a single migration by version. // // The query should return the timestamp and is_applied columns. GetMigrationByVersion(tableName string) string - // ListMigrations returns the SQL query string to list all migrations in - // descending order by id. + // ListMigrations returns the SQL query string to list all migrations in descending order by id. // // The query should return the version_id and is_applied columns. ListMigrations(tableName string) string + + // GetLatestVersion returns the SQL query string to get the last version_id from the db version + // table. Returns a nullable int64 value. + GetLatestVersion(tableName string) string } diff --git a/internal/dialect/dialectquery/mysql.go b/internal/dialect/dialectquery/mysql.go index 25954cbc2..b14ef392b 100644 --- a/internal/dialect/dialectquery/mysql.go +++ b/internal/dialect/dialectquery/mysql.go @@ -36,3 +36,8 @@ func (m *Mysql) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (m *Mysql) GetLatestVersion(tableName string) string { + q := `SELECT MAX(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go index 5103390f4..0faadf5e4 100644 --- a/internal/dialect/dialectquery/postgres.go +++ b/internal/dialect/dialectquery/postgres.go @@ -36,3 +36,8 @@ func (p *Postgres) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (p *Postgres) GetLatestVersion(tableName string) string { + q := `SELECT max(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/redshift.go b/internal/dialect/dialectquery/redshift.go index 006a0ca6d..4090394ce 100644 --- a/internal/dialect/dialectquery/redshift.go +++ b/internal/dialect/dialectquery/redshift.go @@ -36,3 +36,8 @@ func (r *Redshift) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (r *Redshift) GetLatestVersion(tableName string) string { + q := `SELECT max(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/sqlite3.go b/internal/dialect/dialectquery/sqlite3.go index 689900a72..1c58a74bb 100644 --- a/internal/dialect/dialectquery/sqlite3.go +++ b/internal/dialect/dialectquery/sqlite3.go @@ -35,3 +35,8 @@ func (s *Sqlite3) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (s *Sqlite3) GetLatestVersion(tableName string) string { + q := `SELECT MAX(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/sqlserver.go b/internal/dialect/dialectquery/sqlserver.go index 17a617247..4d172c22d 100644 --- a/internal/dialect/dialectquery/sqlserver.go +++ b/internal/dialect/dialectquery/sqlserver.go @@ -35,3 +35,8 @@ func (s *Sqlserver) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (s *Sqlserver) GetLatestVersion(tableName string) string { + q := `SELECT MAX(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/tidb.go b/internal/dialect/dialectquery/tidb.go index 984e60a7a..0549e845e 100644 --- a/internal/dialect/dialectquery/tidb.go +++ b/internal/dialect/dialectquery/tidb.go @@ -36,3 +36,8 @@ func (t *Tidb) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (t *Tidb) GetLatestVersion(tableName string) string { + q := `SELECT MAX(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/vertica.go b/internal/dialect/dialectquery/vertica.go index 4964aeaf6..f4702be54 100644 --- a/internal/dialect/dialectquery/vertica.go +++ b/internal/dialect/dialectquery/vertica.go @@ -36,3 +36,8 @@ func (v *Vertica) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (v *Vertica) GetLatestVersion(tableName string) string { + q := `SELECT MAX(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/ydb.go b/internal/dialect/dialectquery/ydb.go index 4708373d0..ab5e68e3a 100644 --- a/internal/dialect/dialectquery/ydb.go +++ b/internal/dialect/dialectquery/ydb.go @@ -46,3 +46,8 @@ func (c *Ydb) ListMigrations(tableName string) string { FROM %s ORDER BY __discard_column_tstamp DESC` return fmt.Sprintf(q, tableName) } + +func (c *Ydb) GetLatestVersion(tableName string) string { + q := `SELECT MAX(version_id) FROM %s` + return fmt.Sprintf(q, tableName) +} diff --git a/provider.go b/provider.go index 5099789a1..03ee10f37 100644 --- a/provider.go +++ b/provider.go @@ -7,7 +7,6 @@ import ( "fmt" "io/fs" "math" - "sort" "strconv" "strings" "sync" @@ -167,6 +166,10 @@ func (p *Provider) HasPending(ctx context.Context) (bool, error) { // // Note, this method will not use a SessionLocker if one is configured. This allows callers to check // for pending migrations without blocking or being blocked by other operations. +// +// If out-of-order migrations are enabled this method is not suitable for checking pending +// migrations because it ONLY returns the highest version in the database. Instead, use the +// [HasPending] method. func (p *Provider) CheckPending(ctx context.Context) (current, target int64, err error) { return p.checkPending(ctx) } @@ -483,30 +486,22 @@ func (p *Provider) checkPending(ctx context.Context) (current, target int64, ret retErr = multierr.Append(retErr, cleanup()) }() + target = p.migrations[len(p.migrations)-1].Version + // If versioning is disabled, we always have pending migrations and the target version is the // last migration. if p.cfg.disableVersioning { - return -1, p.migrations[len(p.migrations)-1].Version, nil + return -1, target, nil } - // optimize(mf): we should only fetch the max version from the database, no need to fetch all - // migrations only to get the max version when we're not using out-of-order migrations. - res, err := p.store.ListMigrations(ctx, conn) + + current, err = p.store.GetLatestVersion(ctx, conn) if err != nil { - return -1, -1, err - } - dbVersions := make([]int64, 0, len(res)) - for _, m := range res { - dbVersions = append(dbVersions, m.Version) - } - sort.Slice(dbVersions, func(i, j int) bool { - return dbVersions[i] < dbVersions[j] - }) - if len(dbVersions) == 0 { - return -1, -1, errMissingZeroVersion - } else { - current = dbVersions[len(dbVersions)-1] + if errors.Is(err, database.ErrVersionNotFound) { + return -1, target, errMissingZeroVersion + } + return -1, target, err } - return current, p.migrations[len(p.migrations)-1].Version, nil + return current, target, nil } func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { @@ -523,7 +518,8 @@ func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { return true, nil } if p.cfg.allowMissing { - // List all migrations from the database. + // List all migrations from the database. We cannot optimize this because we need to check + // that EVERY migration known the provider has been applied. dbMigrations, err := p.store.ListMigrations(ctx, conn) if err != nil { return false, err @@ -544,16 +540,16 @@ func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { } return false, nil } - // If out-of-order migrations are not allowed, we can optimize this by only checking whether the - // last migration the provider knows about is applied. - last := p.migrations[len(p.migrations)-1] - if _, err := p.store.GetMigration(ctx, conn, last.Version); err != nil { + // If out-of-order migrations are not allowed, we can optimize this by only checking the latest + // version in the database against the latest migration version. + current, err := p.store.GetLatestVersion(ctx, conn) + if err != nil { if errors.Is(err, database.ErrVersionNotFound) { - return true, nil + return false, errMissingZeroVersion } return false, err } - return false, nil + return current < p.migrations[len(p.migrations)-1].Version, nil } func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { @@ -591,9 +587,6 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err // getDBMaxVersion returns the highest version recorded in the database, regardless of the order in // which migrations were applied. conn may be nil, in which case a connection is initialized. -// -// optimize(mf): we should only fetch the max version from the database, no need to fetch all -// migrations only to get the max version. This means expanding the Store interface. func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) { if conn == nil { var cleanup func() error @@ -606,17 +599,13 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64 retErr = multierr.Append(retErr, cleanup()) }() } - res, err := p.store.ListMigrations(ctx, conn) + + latest, err := p.store.GetLatestVersion(ctx, conn) if err != nil { - return 0, err + if errors.Is(err, database.ErrVersionNotFound) { + return 0, errMissingZeroVersion + } + return -1, err } - if len(res) == 0 { - return 0, errMissingZeroVersion - } - // Sort in descending order. - sort.Slice(res, func(i, j int) bool { - return res[i].Version > res[j].Version - }) - // Return the highest version. - return res[0].Version, nil + return latest, nil }