Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testing: replace check with stretchr/testify #842

Merged
merged 5 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 62 additions & 62 deletions database/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"testing"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/stretchr/testify/require"
"go.uber.org/multierr"
"modernc.org/sqlite"
)
Expand All @@ -22,47 +22,47 @@ func TestDialectStore(t *testing.T) {
t.Run("invalid", func(t *testing.T) {
// Test empty table name.
_, err := database.NewStore(database.DialectSQLite3, "")
check.HasError(t, err)
require.Error(t, err)
// Test unknown dialect.
_, err = database.NewStore("unknown-dialect", "foo")
check.HasError(t, err)
require.Error(t, err)
// Test empty dialect.
_, err = database.NewStore("", "foo")
check.HasError(t, err)
require.Error(t, err)
})
// Test generic behavior.
t.Run("sqlite3", func(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
check.NoError(t, err)
require.NoError(t, err)
testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
var sqliteErr *sqlite.Error
ok := errors.As(err, &sqliteErr)
check.Bool(t, ok, true)
check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR)
check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
require.True(t, ok)
require.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR)
require.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
})
})
t.Run("ListMigrations", func(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
require.NoError(t, err)
store, err := database.NewStore(database.DialectSQLite3, "foo")
check.NoError(t, err)
require.NoError(t, err)
err = store.CreateVersionTable(context.Background(), db)
check.NoError(t, err)
require.NoError(t, err)
insert := func(db *sql.DB, version int64) error {
return store.Insert(context.Background(), db, database.InsertRequest{Version: version})
}
check.NoError(t, insert(db, 1))
check.NoError(t, insert(db, 3))
check.NoError(t, insert(db, 2))
require.NoError(t, insert(db, 1))
require.NoError(t, insert(db, 3))
require.NoError(t, insert(db, 2))
res, err := store.ListMigrations(context.Background(), db)
check.NoError(t, err)
check.Number(t, len(res), 3)
require.NoError(t, err)
require.Equal(t, len(res), 3)
// Check versions are in descending order: [2, 3, 1]
check.Number(t, res[0].Version, 2)
check.Number(t, res[1].Version, 3)
check.Number(t, res[2].Version, 1)
require.EqualValues(t, res[0].Version, 2)
require.EqualValues(t, res[1].Version, 3)
require.EqualValues(t, res[2].Version, 1)
})
}

Expand All @@ -81,142 +81,142 @@ func testStore(
tablename = "test_goose_db_version"
)
store, err := database.NewStore(d, tablename)
check.NoError(t, err)
require.NoError(t, err)
// Create the version table.
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.CreateVersionTable(ctx, tx)
})
check.NoError(t, err)
require.NoError(t, err)
// Create the version table again. This should fail.
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.CreateVersionTable(ctx, tx)
})
check.HasError(t, err)
require.Error(t, err)
if alreadyExists != nil {
alreadyExists(t, err)
}
// Get the latest version. There should be none.
_, err = store.GetLatestVersion(ctx, db)
check.IsError(t, err, database.ErrVersionNotFound)
require.ErrorIs(t, err, database.ErrVersionNotFound)

// List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrations(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 0)
require.NoError(t, err)
require.Equal(t, len(res), 0)
return nil
})
check.NoError(t, err)
require.NoError(t, err)

// Insert 5 migrations in addition to the zero migration.
for i := 0; i < 6; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error {
err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
check.NoError(t, err)
require.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, int64(i))
require.NoError(t, err)
require.Equal(t, latest, int64(i))
return nil
})
check.NoError(t, err)
require.NoError(t, err)
}

// List migrations. There should be 6.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrations(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 6)
require.NoError(t, err)
require.Equal(t, len(res), 6)
// Check versions are in descending order.
for i := 0; i < 6; i++ {
check.Number(t, res[i].Version, 5-i)
require.EqualValues(t, res[i].Version, 5-i)
}
return nil
})
check.NoError(t, err)
require.NoError(t, err)

// Delete 3 migrations backwards
for i := 5; i >= 3; i-- {
err = runConn(ctx, db, func(conn *sql.Conn) error {
err := store.Delete(ctx, conn, int64(i))
check.NoError(t, err)
require.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, int64(i-1))
require.NoError(t, err)
require.Equal(t, latest, int64(i-1))
return nil
})
check.NoError(t, err)
require.NoError(t, err)
}

// List migrations. There should be 3.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrations(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 3)
require.NoError(t, err)
require.Equal(t, len(res), 3)
// Check that the remaining versions are in descending order.
for i := 0; i < 3; i++ {
check.Number(t, res[i].Version, 2-i)
require.EqualValues(t, res[i].Version, 2-i)
}
return nil
})
check.NoError(t, err)
require.NoError(t, err)

// Get remaining migrations one by one.
for i := 0; i < 3; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.GetMigration(ctx, conn, int64(i))
check.NoError(t, err)
check.Equal(t, res.IsApplied, true)
check.Equal(t, res.Timestamp.IsZero(), false)
require.NoError(t, err)
require.Equal(t, res.IsApplied, true)
require.Equal(t, res.Timestamp.IsZero(), false)
return nil
})
check.NoError(t, err)
require.NoError(t, err)
}

// Delete remaining migrations one by one and use all 3 connection types:

// 1. *sql.Tx
err = runTx(ctx, db, func(tx *sql.Tx) error {
err := store.Delete(ctx, tx, 2)
check.NoError(t, err)
require.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, tx)
check.NoError(t, err)
check.Number(t, latest, 1)
require.NoError(t, err)
require.EqualValues(t, latest, 1)
return nil
})
check.NoError(t, err)
require.NoError(t, err)
// 2. *sql.Conn
err = runConn(ctx, db, func(conn *sql.Conn) error {
err := store.Delete(ctx, conn, 1)
check.NoError(t, err)
require.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, 0)
require.NoError(t, err)
require.EqualValues(t, latest, 0)
return nil
})
check.NoError(t, err)
require.NoError(t, err)
// 3. *sql.DB
err = store.Delete(ctx, db, 0)
check.NoError(t, err)
require.NoError(t, err)
_, err = store.GetLatestVersion(ctx, db)
check.IsError(t, err, database.ErrVersionNotFound)
require.ErrorIs(t, err, database.ErrVersionNotFound)

// List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrations(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 0)
require.NoError(t, err)
require.Equal(t, len(res), 0)
return nil
})
check.NoError(t, err)
require.NoError(t, err)

// Try to get a migration that does not exist.
err = runConn(ctx, db, func(conn *sql.Conn) error {
_, err := store.GetMigration(ctx, conn, 0)
check.HasError(t, err)
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
require.Error(t, err)
require.True(t, errors.Is(err, database.ErrVersionNotFound))
return nil
})
check.NoError(t, err)
require.NoError(t, err)
}

func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) {
Expand Down
Loading