diff --git a/dbump_sqlite/go.mod b/dbump_sqlite/go.mod new file mode 100644 index 0000000..26919a5 --- /dev/null +++ b/dbump_sqlite/go.mod @@ -0,0 +1,10 @@ +module github.com/cristalhq/dbump/dbump_sqlite + +go 1.17 + +require ( + github.com/cristalhq/dbump v0.14.0 + github.com/mattn/go-sqlite3 v1.14.22 +) + +replace github.com/cristalhq/dbump => ../ diff --git a/dbump_sqlite/go.sum b/dbump_sqlite/go.sum new file mode 100644 index 0000000..e8d092a --- /dev/null +++ b/dbump_sqlite/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/dbump_sqlite/sqlite.go b/dbump_sqlite/sqlite.go new file mode 100644 index 0000000..150fe84 --- /dev/null +++ b/dbump_sqlite/sqlite.go @@ -0,0 +1,138 @@ +package dbump_sqlite + +import ( + "context" + "database/sql" + "errors" + "fmt" + "hash/fnv" + + "github.com/cristalhq/dbump" +) + +var _ dbump.Migrator = &Migrator{} + +// Migrator to migrate Postgres. +type Migrator struct { + conn *sql.DB + cfg Config +} + +// Config for the migrator. +type Config struct { + // Schema for the dbump version table. Default is empty which means "public" schema. + Schema string + // Table for the dbump version table. Default is empty which means "_dbump_log" table. + Table string + + // [schema.]table + tableName string + // to prevent multiple migrations running at the same time + lockNum int64 + + _ struct{} // enforce explicit field names. +} + +// NewMigrator instantiates new Migrator. +func NewMigrator(conn *sql.DB, cfg Config) *Migrator { + if cfg.Schema != "" { + cfg.Schema += "." + } + if cfg.Table == "" { + cfg.Table = "_dbump_log" + } + + cfg.tableName = cfg.Schema + cfg.Table + cfg.lockNum = hashTableName(cfg.tableName) + + return &Migrator{ + conn: conn, + cfg: cfg, + } +} + +// Init is a method from dbump.Migrator interface. +func (s *Migrator) Init(ctx context.Context) error { + query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY, + version INTEGER NOT NULL, + created_at DATETIME NOT NULL +);`, s.cfg.tableName) + + _, err := s.conn.ExecContext(ctx, query) + return err +} + +// Drop is a method from dbump.Migrator interface. +func (s *Migrator) Drop(ctx context.Context) error { + query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, s.cfg.tableName) + + // TODO: probably should ignore error for this query + // if s.cfg.Schema != "" { + // query = fmt.Sprintf(`DROP SCHEMA IF EXISTS %s RESTRICT;`, s.cfg.Schema) + // } + _, err := s.conn.ExecContext(ctx, query) + return err +} + +// LockDB is a method from dbump.Migrator interface. +func (s *Migrator) LockDB(ctx context.Context) error { + return nil +} + +// UnlockDB is a method from dbump.Migrator interface. +func (s *Migrator) UnlockDB(ctx context.Context) error { + return nil +} + +// Version is a method from dbump.Migrator interface. +func (s *Migrator) Version(ctx context.Context) (version int, err error) { + query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", s.cfg.tableName) + row := s.conn.QueryRowContext(ctx, query) + err = row.Scan(&version) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return version, err +} + +// DoStep is a method from dbump.Migrator interface. +func (s *Migrator) DoStep(ctx context.Context, step dbump.Step) error { + if step.DisableTx { + if _, err := s.conn.ExecContext(ctx, step.Query); err != nil { + return err + } + query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName) + _, err := s.conn.ExecContext(ctx, query, step.Version) + return err + } + return s.inTx(ctx, step) +} + +func (s *Migrator) inTx(ctx context.Context, step dbump.Step) error { + tx, err := s.conn.BeginTx(ctx, nil) + if err != nil { + return err + } + + query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName) + _, err = s.conn.ExecContext(ctx, query, step.Version) + // fmt.Printf("args: %s %d\n", query, step.Version) + if err != nil { + _ = tx.Rollback() + return err + } + + _, err = tx.ExecContext(ctx, step.Query) + if err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + +func hashTableName(s string) int64 { + h := fnv.New64() + h.Write([]byte(s)) + return int64(h.Sum64()) +} diff --git a/dbump_sqlite/sqlite_test.go b/dbump_sqlite/sqlite_test.go new file mode 100644 index 0000000..a8dc7cd --- /dev/null +++ b/dbump_sqlite/sqlite_test.go @@ -0,0 +1,180 @@ +package dbump_sqlite + +import ( + "context" + "database/sql" + "os" + "reflect" + "testing" + + "github.com/cristalhq/dbump" + "github.com/cristalhq/dbump/tests" + + _ "github.com/mattn/go-sqlite3" +) + +func TestSQLite_Simple(t *testing.T) { + conn := newTestConn(t) + + m := NewMigrator(conn, Config{}) + l := dbump.NewSliceLoader([]*dbump.Migration{ + { + ID: 1, + Apply: "SELECT 1;", + Revert: "SELECT 1;", + }, + { + ID: 2, + Apply: "SELECT 1;", + Revert: "SELECT 1;", + }, + { + ID: 3, + Apply: "SELECT 1;", + Revert: "SELECT 1;", + }, + }) + + errRun := dbump.Run(context.Background(), dbump.Config{ + Migrator: m, + Loader: l, + Mode: dbump.ModeApplyAll, + }) + mustOK(t, errRun) +} + +func TestNonDefaultSchemaTable(t *testing.T) { + conn := newTestConn(t) + + testCases := []struct { + name string + schema string + table string + wantTableName string + wantLockNum int64 + }{ + { + name: "all empty", + schema: "", + table: "", + wantTableName: "_dbump_log", + wantLockNum: -3987518601082986461, + }, + { + name: "schema set", + schema: "test_schema", + table: "", + wantTableName: "test_schema._dbump_log", + wantLockNum: 1417388815471108263, + }, + { + name: "table set", + schema: "", + table: "test_table", + wantTableName: "test_table", + wantLockNum: 8712390964734167792, + }, + { + name: "schema and table set", + schema: "test_schema", + table: "test_table", + wantTableName: "test_schema.test_table", + wantLockNum: 4631047095544292572, + }, + } + + for _, tc := range testCases { + m := NewMigrator(conn, Config{ + Schema: tc.schema, + Table: tc.table, + }) + mustEqual(t, m.cfg.tableName, tc.wantTableName) + mustEqual(t, m.cfg.lockNum, tc.wantLockNum) + } +} + +func TestMigrate_ApplyAll(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).ApplyAll(t) +} + +func TestMigrate_ApplyOne(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).ApplyOne(t) +} + +func TestMigrate_ApplyAllWhenFull(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).ApplyAllWhenFull(t) +} + +func TestMigrate_RevertOne(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).RevertOne(t) +} + +func TestMigrate_RevertAllWhenEmpty(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).RevertAllWhenEmpty(t) +} + +func TestMigrate_RevertAll(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).RevertAll(t) +} + +func TestMigrate_Redo(t *testing.T) { + conn := newTestConn(t) + newSuite(conn).Redo(t) +} + +func TestMigrate_Drop(t *testing.T) { + conn := newTestConn(t) + suite := newSuite(conn) + suite.SkipCleanup = true + suite.Drop(t) +} + +var envPath = os.Getenv("DBUMP_SQLITE_PATH") + +func newTestConn(tb testing.TB) *sql.DB { + path := "./db.sqlitedb" + if envPath != "" { + path = envPath + } + + // ignore error if it doesn't exist. + _ = os.Remove(path) + + conn, err := sql.Open("sqlite3", path) + mustOK(tb, err) + + tb.Cleanup(func() { + _ = os.Remove(path) + }) + return conn +} + +func newSuite(conn *sql.DB) *tests.MigratorSuite { + m := NewMigrator(conn, Config{}) + suite := tests.NewMigratorSuite(m) + suite.ApplyTmpl = "CREATE TABLE %[1]s_%[2]d (id INT);" + suite.RevertTmpl = "DROP TABLE %[1]s_%[2]d;" + suite.CleanMigTmpl = "DROP TABLE IF EXISTS %[1]s_%[2]d;" + suite.CleanTest = "DELETE FROM _dbump_log;" + return suite +} + +func mustOK(tb testing.TB, err error) { + tb.Helper() + if err != nil { + tb.Fatal(err) + } +} + +func mustEqual(tb testing.TB, got, want interface{}) { + tb.Helper() + if !reflect.DeepEqual(got, want) { + tb.Fatalf("\nhave %+v\nwant %+v", got, want) + } +} diff --git a/tests/tests.go b/tests/tests.go index f87db63..879b29c 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -15,6 +15,7 @@ type MigratorSuite struct { RevertTmpl string CleanMigTmpl string CleanTest string + SkipCleanup bool } func NewMigratorSuite(m dbump.Migrator) *MigratorSuite { @@ -234,10 +235,15 @@ func (suite *MigratorSuite) genMigrations(tb testing.TB, num int, testname strin } tb.Cleanup(func() { + if suite.SkipCleanup { + return + } + for i := 1; i <= num; i++ { query := fmt.Sprintf(suite.CleanMigTmpl, testname, i) failIfErr(tb, suite.migrator.DoStep(context.Background(), dbump.Step{ - Query: query, + Version: num - i - 1, + Query: query, })) } failIfErr(tb, suite.migrator.DoStep(context.Background(), dbump.Step{