Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQLite #20

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dbump_sqlite/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module github.com/cristalhq/dbump/dbump_sqlite

go 1.17

require (
github.com/cristalhq/dbump v0.14.0
github.com/mattn/go-sqlite3 v1.14.22
)

replace github.com/cristalhq/dbump => ../
2 changes: 2 additions & 0 deletions dbump_sqlite/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
138 changes: 138 additions & 0 deletions dbump_sqlite/sqlite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package dbump_sqlite

import (
"context"
"database/sql"
"errors"
"fmt"
"hash/fnv"

"github.com/cristalhq/dbump"
)

var _ dbump.Migrator = &Migrator{}

// Migrator to migrate Postgres.
type Migrator struct {
conn *sql.DB
cfg Config
}

// Config for the migrator.
type Config struct {
// Schema for the dbump version table. Default is empty which means "public" schema.
Schema string
// Table for the dbump version table. Default is empty which means "_dbump_log" table.
Table string

// [schema.]table
tableName string
// to prevent multiple migrations running at the same time
lockNum int64

_ struct{} // enforce explicit field names.
}

// NewMigrator instantiates new Migrator.
func NewMigrator(conn *sql.DB, cfg Config) *Migrator {
if cfg.Schema != "" {
cfg.Schema += "."
}
if cfg.Table == "" {
cfg.Table = "_dbump_log"
}

cfg.tableName = cfg.Schema + cfg.Table
cfg.lockNum = hashTableName(cfg.tableName)

return &Migrator{
conn: conn,
cfg: cfg,
}
}

// Init is a method from dbump.Migrator interface.
func (s *Migrator) Init(ctx context.Context) error {
query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
id INTEGER PRIMARY KEY,
version INTEGER NOT NULL,
created_at DATETIME NOT NULL
);`, s.cfg.tableName)

_, err := s.conn.ExecContext(ctx, query)
return err
}

// Drop is a method from dbump.Migrator interface.
func (s *Migrator) Drop(ctx context.Context) error {
query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, s.cfg.tableName)

// TODO: probably should ignore error for this query
// if s.cfg.Schema != "" {
// query = fmt.Sprintf(`DROP SCHEMA IF EXISTS %s RESTRICT;`, s.cfg.Schema)
// }
_, err := s.conn.ExecContext(ctx, query)
return err
}

// LockDB is a method from dbump.Migrator interface.
func (s *Migrator) LockDB(ctx context.Context) error {
return nil
}

// UnlockDB is a method from dbump.Migrator interface.
func (s *Migrator) UnlockDB(ctx context.Context) error {
return nil
}

// Version is a method from dbump.Migrator interface.
func (s *Migrator) Version(ctx context.Context) (version int, err error) {
query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", s.cfg.tableName)
row := s.conn.QueryRowContext(ctx, query)
err = row.Scan(&version)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return version, err
}

// DoStep is a method from dbump.Migrator interface.
func (s *Migrator) DoStep(ctx context.Context, step dbump.Step) error {
if step.DisableTx {
if _, err := s.conn.ExecContext(ctx, step.Query); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName)
_, err := s.conn.ExecContext(ctx, query, step.Version)
return err
}
return s.inTx(ctx, step)
}

func (s *Migrator) inTx(ctx context.Context, step dbump.Step) error {
tx, err := s.conn.BeginTx(ctx, nil)
if err != nil {
return err
}

query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName)
_, err = s.conn.ExecContext(ctx, query, step.Version)
// fmt.Printf("args: %s %d\n", query, step.Version)
if err != nil {
_ = tx.Rollback()
return err
}

_, err = tx.ExecContext(ctx, step.Query)
if err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}

func hashTableName(s string) int64 {
h := fnv.New64()
h.Write([]byte(s))
return int64(h.Sum64())
}
180 changes: 180 additions & 0 deletions dbump_sqlite/sqlite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package dbump_sqlite

import (
"context"
"database/sql"
"os"
"reflect"
"testing"

"github.com/cristalhq/dbump"
"github.com/cristalhq/dbump/tests"

_ "github.com/mattn/go-sqlite3"
)

func TestSQLite_Simple(t *testing.T) {
conn := newTestConn(t)

m := NewMigrator(conn, Config{})
l := dbump.NewSliceLoader([]*dbump.Migration{
{
ID: 1,
Apply: "SELECT 1;",
Revert: "SELECT 1;",
},
{
ID: 2,
Apply: "SELECT 1;",
Revert: "SELECT 1;",
},
{
ID: 3,
Apply: "SELECT 1;",
Revert: "SELECT 1;",
},
})

errRun := dbump.Run(context.Background(), dbump.Config{
Migrator: m,
Loader: l,
Mode: dbump.ModeApplyAll,
})
mustOK(t, errRun)
}

func TestNonDefaultSchemaTable(t *testing.T) {
conn := newTestConn(t)

testCases := []struct {
name string
schema string
table string
wantTableName string
wantLockNum int64
}{
{
name: "all empty",
schema: "",
table: "",
wantTableName: "_dbump_log",
wantLockNum: -3987518601082986461,
},
{
name: "schema set",
schema: "test_schema",
table: "",
wantTableName: "test_schema._dbump_log",
wantLockNum: 1417388815471108263,
},
{
name: "table set",
schema: "",
table: "test_table",
wantTableName: "test_table",
wantLockNum: 8712390964734167792,
},
{
name: "schema and table set",
schema: "test_schema",
table: "test_table",
wantTableName: "test_schema.test_table",
wantLockNum: 4631047095544292572,
},
}

for _, tc := range testCases {
m := NewMigrator(conn, Config{
Schema: tc.schema,
Table: tc.table,
})
mustEqual(t, m.cfg.tableName, tc.wantTableName)
mustEqual(t, m.cfg.lockNum, tc.wantLockNum)
}
}

func TestMigrate_ApplyAll(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).ApplyAll(t)
}

func TestMigrate_ApplyOne(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).ApplyOne(t)
}

func TestMigrate_ApplyAllWhenFull(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).ApplyAllWhenFull(t)
}

func TestMigrate_RevertOne(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).RevertOne(t)
}

func TestMigrate_RevertAllWhenEmpty(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).RevertAllWhenEmpty(t)
}

func TestMigrate_RevertAll(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).RevertAll(t)
}

func TestMigrate_Redo(t *testing.T) {
conn := newTestConn(t)
newSuite(conn).Redo(t)
}

func TestMigrate_Drop(t *testing.T) {
conn := newTestConn(t)
suite := newSuite(conn)
suite.SkipCleanup = true
suite.Drop(t)
}

var envPath = os.Getenv("DBUMP_SQLITE_PATH")

func newTestConn(tb testing.TB) *sql.DB {
path := "./db.sqlitedb"
if envPath != "" {
path = envPath
}

// ignore error if it doesn't exist.
_ = os.Remove(path)

conn, err := sql.Open("sqlite3", path)
mustOK(tb, err)

tb.Cleanup(func() {
_ = os.Remove(path)
})
return conn
}

func newSuite(conn *sql.DB) *tests.MigratorSuite {
m := NewMigrator(conn, Config{})
suite := tests.NewMigratorSuite(m)
suite.ApplyTmpl = "CREATE TABLE %[1]s_%[2]d (id INT);"
suite.RevertTmpl = "DROP TABLE %[1]s_%[2]d;"
suite.CleanMigTmpl = "DROP TABLE IF EXISTS %[1]s_%[2]d;"
suite.CleanTest = "DELETE FROM _dbump_log;"
return suite
}

func mustOK(tb testing.TB, err error) {
tb.Helper()
if err != nil {
tb.Fatal(err)
}
}

func mustEqual(tb testing.TB, got, want interface{}) {
tb.Helper()
if !reflect.DeepEqual(got, want) {
tb.Fatalf("\nhave %+v\nwant %+v", got, want)
}
}
8 changes: 7 additions & 1 deletion tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type MigratorSuite struct {
RevertTmpl string
CleanMigTmpl string
CleanTest string
SkipCleanup bool
}

func NewMigratorSuite(m dbump.Migrator) *MigratorSuite {
Expand Down Expand Up @@ -234,10 +235,15 @@ func (suite *MigratorSuite) genMigrations(tb testing.TB, num int, testname strin
}

tb.Cleanup(func() {
if suite.SkipCleanup {
return
}

for i := 1; i <= num; i++ {
query := fmt.Sprintf(suite.CleanMigTmpl, testname, i)
failIfErr(tb, suite.migrator.DoStep(context.Background(), dbump.Step{
Query: query,
Version: num - i - 1,
Query: query,
}))
}
failIfErr(tb, suite.migrator.DoStep(context.Background(), dbump.Step{
Expand Down
Loading