From ccfb885423604a30b60ddce79439def8a7900e7d Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Fri, 6 Oct 2023 23:53:26 -0400 Subject: [PATCH] feat(experimental): goose provider with unimplemented methods (#596) --- dialect.go | 14 ++ go.mod | 1 + go.sum | 2 + internal/sqladapter/sqladapter.go | 49 +++++++ internal/sqladapter/store.go | 111 ++++++++++++++ internal/sqladapter/store_test.go | 218 ++++++++++++++++++++++++++++ internal/sqlextended/sqlextended.go | 23 +++ internal/sqlparser/parser.go | 8 + provider.go | 196 +++++++++++++++++++++++++ provider_options.go | 50 +++++++ provider_options_test.go | 100 +++++++++++++ 11 files changed, 772 insertions(+) create mode 100644 internal/sqladapter/sqladapter.go create mode 100644 internal/sqladapter/store.go create mode 100644 internal/sqladapter/store_test.go create mode 100644 internal/sqlextended/sqlextended.go create mode 100644 provider.go create mode 100644 provider_options.go create mode 100644 provider_options_test.go diff --git a/dialect.go b/dialect.go index a14248002..83c81c4dd 100644 --- a/dialect.go +++ b/dialect.go @@ -6,6 +6,20 @@ import ( "github.com/pressly/goose/v3/internal/dialect" ) +// Dialect is the type of database dialect. +type Dialect string + +const ( + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" +) + func init() { store, _ = dialect.NewStore(dialect.Postgres) } diff --git a/go.mod b/go.mod index 931d9a985..9beb962cf 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/ory/dockertest/v3 v3.10.0 github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 + go.uber.org/multierr v1.11.0 modernc.org/sqlite v1.26.0 ) diff --git a/go.sum b/go.sum index 0c34e9ce7..59b8881ef 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go new file mode 100644 index 000000000..f6c975dc4 --- /dev/null +++ b/internal/sqladapter/sqladapter.go @@ -0,0 +1,49 @@ +// Package sqladapter provides an interface for interacting with a SQL database. +// +// All supported database dialects must implement the Store interface. +package sqladapter + +import ( + "context" + "time" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +// Store is the interface that wraps the basic methods for a database dialect. +// +// A dialect is a set of SQL statements that are specific to a database. +// +// By defining a store interface, we can support multiple databases with a single codebase. +// +// The underlying implementation does not modify the error. It is the callers responsibility to +// assert for the correct error, such as [sql.ErrNoRows]. +type Store interface { + // CreateVersionTable creates the version table within a transaction. This table is used to + // record applied migrations. + CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error + + // InsertOrDelete inserts or deletes a version id from the version table. + InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error + + // GetMigration retrieves a single migration by version id. + // + // Returns the raw sql error if the query fails. It is the callers responsibility to assert for + // the correct error, such as [sql.ErrNoRows]. + GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id. + // + // If there are no migrations, an empty slice is returned with no error. + ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) +} + +type GetMigrationResult struct { + IsApplied bool + Timestamp time.Time +} + +type ListMigrationsResult struct { + Version int64 + IsApplied bool +} diff --git a/internal/sqladapter/store.go b/internal/sqladapter/store.go new file mode 100644 index 000000000..0ee90ca49 --- /dev/null +++ b/internal/sqladapter/store.go @@ -0,0 +1,111 @@ +package sqladapter + +import ( + "context" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/dialect/dialectquery" + "github.com/pressly/goose/v3/internal/sqlextended" +) + +var _ Store = (*store)(nil) + +type store struct { + tablename string + querier dialectquery.Querier +} + +// NewStore returns a new [Store] backed by the given dialect. +// +// The dialect must match one of the supported dialects defined in dialect.go. +func NewStore(dialect string, table string) (Store, error) { + if table == "" { + return nil, errors.New("table must not be empty") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + var querier dialectquery.Querier + switch dialect { + case "clickhouse": + querier = &dialectquery.Clickhouse{} + case "mssql": + querier = &dialectquery.Sqlserver{} + case "mysql": + querier = &dialectquery.Mysql{} + case "postgres": + querier = &dialectquery.Postgres{} + case "redshift": + querier = &dialectquery.Redshift{} + case "sqlite3": + querier = &dialectquery.Sqlite3{} + case "tidb": + querier = &dialectquery.Tidb{} + case "vertica": + querier = &dialectquery.Vertica{} + default: + return nil, fmt.Errorf("unknown dialect: %q", dialect) + } + return &store{ + tablename: table, + querier: querier, + }, nil +} + +func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error { + q := s.querier.CreateTable(s.tablename) + if _, err := db.ExecContext(ctx, q); err != nil { + return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) + } + return nil +} + +func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error { + if direction { + q := s.querier.InsertVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version, true); err != nil { + return fmt.Errorf("failed to insert version %d: %w", version, err) + } + return nil + } + q := s.querier.DeleteVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version); err != nil { + return fmt.Errorf("failed to delete version %d: %w", version, err) + } + return nil +} + +func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion(s.tablename) + var result GetMigrationResult + if err := db.QueryRowContext(ctx, q, version).Scan( + &result.Timestamp, + &result.IsApplied, + ); err != nil { + return nil, fmt.Errorf("failed to get migration %d: %w", version, err) + } + return &result, nil +} + +func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) { + q := s.querier.ListMigrations(s.tablename) + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf("failed to list migrations: %w", err) + } + defer rows.Close() + + var migrations []*ListMigrationsResult + for rows.Next() { + var result ListMigrationsResult + if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { + return nil, fmt.Errorf("failed to scan list migrations result: %w", err) + } + migrations = append(migrations, &result) + } + if err := rows.Err(); err != nil { + return nil, err + } + return migrations, nil +} diff --git a/internal/sqladapter/store_test.go b/internal/sqladapter/store_test.go new file mode 100644 index 000000000..1d0189598 --- /dev/null +++ b/internal/sqladapter/store_test.go @@ -0,0 +1,218 @@ +package sqladapter_test + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/testdb" + "go.uber.org/multierr" + "modernc.org/sqlite" +) + +// The goal of this test is to verify the sqladapter package works as expected. This test is not +// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store +// interface works against a real database. + +func TestStore(t *testing.T) { + t.Parallel() + t.Run("invalid", func(t *testing.T) { + // Test empty table name. + _, err := sqladapter.NewStore("sqlite3", "") + check.HasError(t, err) + // Test unknown dialect. + _, err = sqladapter.NewStore("unknown-dialect", "foo") + check.HasError(t, err) + // Test empty dialect. + _, err = sqladapter.NewStore("", "foo") + check.HasError(t, err) + }) + t.Run("postgres", func(t *testing.T) { + if testing.Short() { + t.Skip("skip long-running test") + } + // Test postgres specific behavior. + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) { + var pgErr *pgconn.PgError + ok := errors.As(err, &pgErr) + check.Bool(t, ok, true) + check.Equal(t, pgErr.Code, "42P07") // duplicate_table + }) + }) + // Test generic behavior. + t.Run("sqlite3", func(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) { + var sqliteErr *sqlite.Error + ok := errors.As(err, &sqliteErr) + check.Bool(t, ok, true) + check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) + check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") + }) + }) +} + +// testStore tests various store operations. +// +// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable +// when the version table already exists. +func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { + const ( + tablename = "test_goose_db_version" + ) + store, err := sqladapter.NewStore(string(dialect), tablename) + check.NoError(t, err) + // Create the version table. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.NoError(t, err) + // Create the version table again. This should fail. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.HasError(t, err) + if alreadyExists != nil { + alreadyExists(t, err) + } + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Insert 5 migrations in addition to the zero migration. + for i := 0; i < 6; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, true, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 6. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 6) + // Check versions are in descending order. + for i := 0; i < 6; i++ { + check.Number(t, res[i].Version, 5-i) + } + return nil + }) + check.NoError(t, err) + + // Delete 3 migrations backwards + for i := 5; i >= 3; i-- { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 3. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check that the remaining versions are in descending order. + for i := 0; i < 3; i++ { + check.Number(t, res[i].Version, 2-i) + } + return nil + }) + check.NoError(t, err) + + // Get remaining migrations one by one. + for i := 0; i < 3; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.GetMigration(ctx, conn, int64(i)) + check.NoError(t, err) + check.Equal(t, res.IsApplied, true) + check.Equal(t, res.Timestamp.IsZero(), false) + return nil + }) + check.NoError(t, err) + } + + // Delete remaining migrations one by one and use all 3 connection types: + + // 1. *sql.Tx + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.InsertOrDelete(ctx, tx, false, 2) + }) + check.NoError(t, err) + // 2. *sql.Conn + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, 1) + }) + check.NoError(t, err) + // 3. *sql.DB + err = store.InsertOrDelete(ctx, db, false, 0) + check.NoError(t, err) + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Try to get a migration that does not exist. + err = runConn(ctx, db, func(conn *sql.Conn) error { + _, err := store.GetMigration(ctx, conn, 0) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrNoRows), true) + return nil + }) + check.NoError(t, err) +} + +func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, conn.Close()) + } + }() + if err := fn(conn); err != nil { + return err + } + return conn.Close() +} diff --git a/internal/sqlextended/sqlextended.go b/internal/sqlextended/sqlextended.go new file mode 100644 index 000000000..e3e763abf --- /dev/null +++ b/internal/sqlextended/sqlextended.go @@ -0,0 +1,23 @@ +package sqlextended + +import ( + "context" + "database/sql" +) + +// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and +// *sql.Conn. +// +// There is a long outstanding issue to formalize a std lib interface, but alas... See: +// https://github.com/golang/go/issues/14468 +type DBTxConn interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +var ( + _ DBTxConn = (*sql.DB)(nil) + _ DBTxConn = (*sql.Tx)(nil) + _ DBTxConn = (*sql.Conn)(nil) +) diff --git a/internal/sqlparser/parser.go b/internal/sqlparser/parser.go index 5e6c67503..a62846026 100644 --- a/internal/sqlparser/parser.go +++ b/internal/sqlparser/parser.go @@ -25,6 +25,14 @@ func FromBool(b bool) Direction { return DirectionDown } +func (d Direction) String() string { + return string(d) +} + +func (d Direction) ToBool() bool { + return d == DirectionUp +} + type parserState int const ( diff --git a/provider.go b/provider.go new file mode 100644 index 000000000..c12d4ea8b --- /dev/null +++ b/provider.go @@ -0,0 +1,196 @@ +package goose + +import ( + "context" + "database/sql" + "errors" + "io/fs" + "time" + + "github.com/pressly/goose/v3/internal/sqladapter" +) + +// NewProvider returns a new goose Provider. +// +// The caller is responsible for matching the database dialect with the database/sql driver. For +// example, if the database dialect is "postgres", the database/sql driver could be +// github.com/lib/pq or github.com/jackc/pgx. +// +// fsys is the filesystem used to read the migration files. Most users will want to use +// os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is +// possible to use a different filesystem, such as embed.FS. +// +// Functional options are used to configure the Provider. See [ProviderOption] for more information. +// +// Unless otherwise specified, all methods on Provider are safe for concurrent use. +// +// Experimental: This API is experimental and may change in the future. +func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { + if db == nil { + return nil, errors.New("db must not be nil") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + if fsys == nil { + return nil, errors.New("fsys must not be nil") + } + var cfg config + for _, opt := range opts { + if err := opt.apply(&cfg); err != nil { + return nil, err + } + } + // Set defaults + if cfg.tableName == "" { + cfg.tableName = defaultTablename + } + store, err := sqladapter.NewStore(string(dialect), cfg.tableName) + if err != nil { + return nil, err + } + // TODO(mf): implement the rest of this function - collect sources - merge sources into + // migrations + return &Provider{ + db: db, + fsys: fsys, + cfg: cfg, + store: store, + }, nil +} + +// Provider is a goose migration provider. +// Experimental: This API is experimental and may change in the future. +type Provider struct { + db *sql.DB + fsys fs.FS + cfg config + store sqladapter.Store +} + +// MigrationStatus represents the status of a single migration. +type MigrationStatus struct { + // State represents the state of the migration. One of "untracked", "pending", "applied". + // - untracked: in the database, but not on the filesystem. + // - pending: on the filesystem, but not in the database. + // - applied: in both the database and on the filesystem. + State string + // AppliedAt is the time the migration was applied. Only set if state is applied or untracked. + AppliedAt time.Time + // Source is the migration source. Only set if the state is pending or applied. + Source Source +} + +// Status returns the status of all migrations, merging the list of migrations from the database and +// filesystem. The returned items are ordered by version, in ascending order. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { + return nil, errors.New("not implemented") +} + +// GetDBVersion returns the max version from the database, regardless of the applied order. For +// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been +// applied, it returns 0. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} + +// SourceType represents the type of migration source. +type SourceType string + +const ( + // SourceTypeSQL represents a SQL migration. + SourceTypeSQL SourceType = "sql" + // SourceTypeGo represents a Go migration. + SourceTypeGo SourceType = "go" +) + +// Source represents a single migration source. +// +// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if +// the migration has a corresponding file on disk. It will be empty if the migration was registered +// manually. +// Experimental: This API is experimental and may change in the future. +type Source struct { + // Type is the type of migration. + Type SourceType + // Full path to the migration file. + // + // Example: /path/to/migrations/001_create_users_table.sql + Fullpath string + // Version is the version of the migration. + Version int64 +} + +// ListSources returns a list of all available migration sources the provider is aware of. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) ListSources() []*Source { + return nil +} + +// Ping attempts to ping the database to verify a connection is available. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Ping(ctx context.Context) error { + return errors.New("not implemented") +} + +// Close closes the database connection. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Close() error { + return errors.New("not implemented") +} + +// MigrationResult represents the result of a single migration. +type MigrationResult struct{} + +// ApplyVersion applies exactly one migration at the specified version. If there is no source for +// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been +// applied already, this method returns [ErrAlreadyApplied]. +// +// When direction is true, the up migration is executed, and when direction is false, the down +// migration is executed. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// Up applies all pending migrations. If there are no new migrations to apply, this method returns +// empty list and nil error. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// UpByOne applies the next available migration. If there are no migrations to apply, this method +// returns [ErrNoNextVersion]. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// UpTo applies all available migrations up to and including the specified version. If there are no +// migrations to apply, this method returns empty list and nil error. +// +// For instance, if there are three new migrations (9,10,11) and the current database version is 8 +// with a requested version of 10, only versions 9 and 10 will be applied. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// Down rolls back the most recently applied migration. If there are no migrations to apply, this +// method returns [ErrNoNextVersion]. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// DownTo rolls back all migrations down to but not including the specified version. +// +// For instance, if the current database version is 11, and the requested version is 9, only +// migrations 11 and 10 will be rolled back. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + return nil, errors.New("not implemented") +} diff --git a/provider_options.go b/provider_options.go new file mode 100644 index 000000000..904b3ed34 --- /dev/null +++ b/provider_options.go @@ -0,0 +1,50 @@ +package goose + +import ( + "errors" + "fmt" +) + +const ( + defaultTablename = "goose_db_version" +) + +// ProviderOption is a configuration option for a goose provider. +type ProviderOption interface { + apply(*config) error +} + +// WithTableName sets the name of the database table used to track history of applied migrations. +// +// If WithTableName is not called, the default value is "goose_db_version". +func WithTableName(name string) ProviderOption { + return configFunc(func(c *config) error { + if c.tableName != "" { + return fmt.Errorf("table already set to %q", c.tableName) + } + if name == "" { + return errors.New("table must not be empty") + } + c.tableName = name + return nil + }) +} + +// WithVerbose enables verbose logging. +func WithVerbose() ProviderOption { + return configFunc(func(c *config) error { + c.verbose = true + return nil + }) +} + +type config struct { + tableName string + verbose bool +} + +type configFunc func(*config) error + +func (o configFunc) apply(cfg *config) error { + return o(cfg) +} diff --git a/provider_options_test.go b/provider_options_test.go new file mode 100644 index 000000000..629c6efaa --- /dev/null +++ b/provider_options_test.go @@ -0,0 +1,100 @@ +package goose_test + +import ( + "database/sql" + "io/fs" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" +) + +func TestNewProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + fsys := newFsys() + t.Run("invalid", func(t *testing.T) { + // Empty dialect not allowed + _, err = goose.NewProvider("", db, fsys) + check.HasError(t, err) + // Invalid dialect not allowed + _, err = goose.NewProvider("unknown-dialect", db, fsys) + check.HasError(t, err) + // Nil db not allowed + _, err = goose.NewProvider("sqlite3", nil, fsys) + check.HasError(t, err) + // Nil fsys not allowed + _, err = goose.NewProvider("sqlite3", db, nil) + check.HasError(t, err) + // Duplicate table name not allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("foo"), goose.WithTableName("bar")) + check.HasError(t, err) + check.Equal(t, `table already set to "foo"`, err.Error()) + // Empty table name not allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("")) + check.HasError(t, err) + check.Equal(t, "table must not be empty", err.Error()) + }) + t.Run("valid", func(t *testing.T) { + // Valid dialect, db, and fsys allowed + _, err = goose.NewProvider("sqlite3", db, fsys) + check.NoError(t, err) + // Valid dialect, db, fsys, and table name allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("foo")) + check.NoError(t, err) + // Valid dialect, db, fsys, and verbose allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithVerbose()) + check.NoError(t, err) + }) +} + +func newFsys() fs.FS { + return fstest.MapFS{ + "1_foo.sql": {Data: []byte(migration1)}, + "2_bar.sql": {Data: []byte(migration2)}, + "3_baz.sql": {Data: []byte(migration3)}, + "4_qux.sql": {Data: []byte(migration4)}, + } +} + +var ( + migration1 = ` +-- +goose Up +CREATE TABLE foo (id INTEGER PRIMARY KEY); +-- +goose Down +DROP TABLE foo; +` + migration2 = ` +-- +goose Up +ALTER TABLE foo ADD COLUMN name TEXT; +-- +goose Down +ALTER TABLE foo DROP COLUMN name; +` + migration3 = ` +-- +goose Up +CREATE TABLE bar ( + id INTEGER PRIMARY KEY, + description TEXT +); +-- +goose Down +DROP TABLE bar; +` + migration4 = ` +-- +goose Up +-- Rename the 'foo' table to 'my_foo' +ALTER TABLE foo RENAME TO my_foo; + +-- Add a new column 'timestamp' to 'my_foo' +ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; + +-- +goose Down +-- Remove the 'timestamp' column from 'my_foo' +ALTER TABLE my_foo DROP COLUMN timestamp; + +-- Rename the 'my_foo' table back to 'foo' +ALTER TABLE my_foo RENAME TO foo; +` +)