Skip to content

Commit

Permalink
feat(experimental): goose provider with unimplemented methods (#596)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Oct 7, 2023
1 parent 473a70d commit ccfb885
Show file tree
Hide file tree
Showing 11 changed files with 772 additions and 0 deletions.
14 changes: 14 additions & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ import (
"github.com/pressly/goose/v3/internal/dialect"
)

// Dialect is the type of database dialect.
type Dialect string

const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
)

func init() {
store, _ = dialect.NewStore(dialect.Postgres)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/ory/dockertest/v3 v3.10.0
github.com/vertica/vertica-sql-go v1.3.3
github.com/ziutek/mymysql v1.5.4
go.uber.org/multierr v1.11.0
modernc.org/sqlite v1.26.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs=
go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY=
go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg=
go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
Expand Down
49 changes: 49 additions & 0 deletions internal/sqladapter/sqladapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Package sqladapter provides an interface for interacting with a SQL database.
//
// All supported database dialects must implement the Store interface.
package sqladapter

import (
"context"
"time"

"github.com/pressly/goose/v3/internal/sqlextended"
)

// Store is the interface that wraps the basic methods for a database dialect.
//
// A dialect is a set of SQL statements that are specific to a database.
//
// By defining a store interface, we can support multiple databases with a single codebase.
//
// The underlying implementation does not modify the error. It is the callers responsibility to
// assert for the correct error, such as [sql.ErrNoRows].
type Store interface {
// CreateVersionTable creates the version table within a transaction. This table is used to
// record applied migrations.
CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error

// InsertOrDelete inserts or deletes a version id from the version table.
InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error

// GetMigration retrieves a single migration by version id.
//
// Returns the raw sql error if the query fails. It is the callers responsibility to assert for
// the correct error, such as [sql.ErrNoRows].
GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id.
//
// If there are no migrations, an empty slice is returned with no error.
ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error)
}

type GetMigrationResult struct {
IsApplied bool
Timestamp time.Time
}

type ListMigrationsResult struct {
Version int64
IsApplied bool
}
111 changes: 111 additions & 0 deletions internal/sqladapter/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package sqladapter

import (
"context"
"errors"
"fmt"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
"github.com/pressly/goose/v3/internal/sqlextended"
)

var _ Store = (*store)(nil)

type store struct {
tablename string
querier dialectquery.Querier
}

// NewStore returns a new [Store] backed by the given dialect.
//
// The dialect must match one of the supported dialects defined in dialect.go.
func NewStore(dialect string, table string) (Store, error) {
if table == "" {
return nil, errors.New("table must not be empty")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
var querier dialectquery.Querier
switch dialect {
case "clickhouse":
querier = &dialectquery.Clickhouse{}
case "mssql":
querier = &dialectquery.Sqlserver{}
case "mysql":
querier = &dialectquery.Mysql{}
case "postgres":
querier = &dialectquery.Postgres{}
case "redshift":
querier = &dialectquery.Redshift{}
case "sqlite3":
querier = &dialectquery.Sqlite3{}
case "tidb":
querier = &dialectquery.Tidb{}
case "vertica":
querier = &dialectquery.Vertica{}
default:
return nil, fmt.Errorf("unknown dialect: %q", dialect)
}
return &store{
tablename: table,
querier: querier,
}, nil
}

func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error {
q := s.querier.CreateTable(s.tablename)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
}
return nil
}

func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error {
if direction {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", version, err)
}
return nil
}
q := s.querier.DeleteVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
}
return nil
}

func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
var result GetMigrationResult
if err := db.QueryRowContext(ctx, q, version).Scan(
&result.Timestamp,
&result.IsApplied,
); err != nil {
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return &result, nil
}

func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
rows, err := db.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("failed to list migrations: %w", err)
}
defer rows.Close()

var migrations []*ListMigrationsResult
for rows.Next() {
var result ListMigrationsResult
if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
}
migrations = append(migrations, &result)
}
if err := rows.Err(); err != nil {
return nil, err
}
return migrations, nil
}
Loading

0 comments on commit ccfb885

Please sign in to comment.