From 835a81249c5d88a0745ea9ef6d12031943e634f9 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Wed, 19 Jun 2024 19:45:14 +0200 Subject: [PATCH] Add SQLite --- dbump_sqlite/go.mod | 10 +++ dbump_sqlite/go.sum | 2 + dbump_sqlite/sqlite.go | 138 +++++++++++++++++++++++++++++ dbump_sqlite/sqlite_test.go | 169 ++++++++++++++++++++++++++++++++++++ 4 files changed, 319 insertions(+) create mode 100644 dbump_sqlite/go.mod create mode 100644 dbump_sqlite/go.sum create mode 100644 dbump_sqlite/sqlite.go create mode 100644 dbump_sqlite/sqlite_test.go diff --git a/dbump_sqlite/go.mod b/dbump_sqlite/go.mod new file mode 100644 index 0000000..26919a5 --- /dev/null +++ b/dbump_sqlite/go.mod @@ -0,0 +1,10 @@ +module github.com/cristalhq/dbump/dbump_sqlite + +go 1.17 + +require ( + github.com/cristalhq/dbump v0.14.0 + github.com/mattn/go-sqlite3 v1.14.22 +) + +replace github.com/cristalhq/dbump => ../ diff --git a/dbump_sqlite/go.sum b/dbump_sqlite/go.sum new file mode 100644 index 0000000..e8d092a --- /dev/null +++ b/dbump_sqlite/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/dbump_sqlite/sqlite.go b/dbump_sqlite/sqlite.go new file mode 100644 index 0000000..ee47085 --- /dev/null +++ b/dbump_sqlite/sqlite.go @@ -0,0 +1,138 @@ +package dbump_sqlite + +import ( + "context" + "database/sql" + "errors" + "fmt" + "hash/fnv" + + "github.com/cristalhq/dbump" +) + +var _ dbump.Migrator = &Migrator{} + +// Migrator to migrate Postgres. +type Migrator struct { + conn *sql.DB + cfg Config +} + +// Config for the migrator. +type Config struct { + // Schema for the dbump version table. Default is empty which means "public" schema. + Schema string + // Table for the dbump version table. Default is empty which means "_dbump_log" table. + Table string + + // [schema.]table + tableName string + // to prevent multiple migrations running at the same time + lockNum int64 + + _ struct{} // enforce explicit field names. +} + +// NewMigrator instantiates new Migrator. +func NewMigrator(conn *sql.DB, cfg Config) *Migrator { + if cfg.Schema != "" { + cfg.Schema += "." + } + if cfg.Table == "" { + cfg.Table = "_dbump_log" + } + + cfg.tableName = cfg.Schema + cfg.Table + cfg.lockNum = hashTableName(cfg.tableName) + + return &Migrator{ + conn: conn, + cfg: cfg, + } +} + +// Init is a method from dbump.Migrator interface. +func (s *Migrator) Init(ctx context.Context) error { + query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + version INTEGER NOT NULL, + created_at TEXT NOT NULL +);`, s.cfg.tableName) + + _, err := s.conn.ExecContext(ctx, query) + return err +} + +// Drop is a method from dbump.Migrator interface. +func (s *Migrator) Drop(ctx context.Context) error { + query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, s.cfg.tableName) + + // TODO: probably should ignore error for this query + // if s.cfg.Schema != "" { + // query = fmt.Sprintf(`DROP SCHEMA IF EXISTS %s RESTRICT;`, s.cfg.Schema) + // } + _, err := s.conn.ExecContext(ctx, query) + return err +} + +// LockDB is a method from dbump.Migrator interface. +func (s *Migrator) LockDB(ctx context.Context) error { + return nil +} + +// UnlockDB is a method from dbump.Migrator interface. +func (s *Migrator) UnlockDB(ctx context.Context) error { + return nil +} + +// Version is a method from dbump.Migrator interface. +func (s *Migrator) Version(ctx context.Context) (version int, err error) { + query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", s.cfg.tableName) + row := s.conn.QueryRowContext(ctx, query) + err = row.Scan(&version) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return version, err +} + +// DoStep is a method from dbump.Migrator interface. +func (s *Migrator) DoStep(ctx context.Context, step dbump.Step) error { + if step.DisableTx { + if _, err := s.conn.ExecContext(ctx, step.Query); err != nil { + return err + } + query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName) + _, err := s.conn.ExecContext(ctx, query, step.Version) + return err + } + return s.inTx(ctx, step) +} + +func (s *Migrator) inTx(ctx context.Context, step dbump.Step) error { + tx, err := s.conn.BeginTx(ctx, nil) + if err != nil { + return err + } + + query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName) + _, err = s.conn.ExecContext(ctx, query, step.Version) + // fmt.Printf("args: %s %d\n", query, step.Version) + if err != nil { + _ = tx.Rollback() + return err + } + + _, err = tx.ExecContext(ctx, step.Query) + if err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + +func hashTableName(s string) int64 { + h := fnv.New64() + h.Write([]byte(s)) + return int64(h.Sum64()) +} diff --git a/dbump_sqlite/sqlite_test.go b/dbump_sqlite/sqlite_test.go new file mode 100644 index 0000000..ca0389a --- /dev/null +++ b/dbump_sqlite/sqlite_test.go @@ -0,0 +1,169 @@ +package dbump_sqlite + +import ( + "context" + "database/sql" + "os" + "reflect" + "testing" + + "github.com/cristalhq/dbump" + "github.com/cristalhq/dbump/tests" + + _ "github.com/mattn/go-sqlite3" +) + +var conn *sql.DB + +func init() { + path := os.Getenv("DBUMP_SQLITE_PATH") + if path == "" { + path = "./db.sqlitedb" // + time.Now().String() + } + + var err error + conn, err = sql.Open("sqlite3", path) + if err != nil { + panic(err) + } +} + +func TestSQLite_Simple(t *testing.T) { + m := NewMigrator(conn, Config{}) + l := dbump.NewSliceLoader([]*dbump.Migration{ + { + ID: 1, + Apply: "SELECT 1;", + Revert: "SELECT 1;", + }, + { + ID: 2, + Apply: "SELECT 1;", + Revert: "SELECT 1;", + }, + { + ID: 3, + Apply: "SELECT 1;", + Revert: "SELECT 1;", + }, + }) + + errRun := dbump.Run(context.Background(), dbump.Config{ + Migrator: m, + Loader: l, + Mode: dbump.ModeApplyAll, + }) + failIfErr(t, errRun) +} + +func TestNonDefaultSchemaTable(t *testing.T) { + testCases := []struct { + name string + schema string + table string + wantTableName string + wantLockNum int64 + }{ + { + name: "all empty", + schema: "", + table: "", + wantTableName: "_dbump_log", + wantLockNum: -3987518601082986461, + }, + { + name: "schema set", + schema: "test_schema", + table: "", + wantTableName: "test_schema._dbump_log", + wantLockNum: 1417388815471108263, + }, + { + name: "table set", + schema: "", + table: "test_table", + wantTableName: "test_table", + wantLockNum: 8712390964734167792, + }, + { + name: "schema and table set", + schema: "test_schema", + table: "test_table", + wantTableName: "test_schema.test_table", + wantLockNum: 4631047095544292572, + }, + } + + for _, tc := range testCases { + m := NewMigrator(conn, Config{ + Schema: tc.schema, + Table: tc.table, + }) + mustEqual(t, m.cfg.tableName, tc.wantTableName) + mustEqual(t, m.cfg.lockNum, tc.wantLockNum) + } +} + +func TestMigrate_ApplyAll(t *testing.T) { + newSuite().ApplyAll(t) +} + +func TestMigrate_ApplyOne(t *testing.T) { + newSuite().ApplyOne(t) +} + +func TestMigrate_ApplyAllWhenFull(t *testing.T) { + newSuite().ApplyAllWhenFull(t) +} + +func TestMigrate_RevertOne(t *testing.T) { + newSuite().RevertOne(t) +} + +func TestMigrate_RevertAllWhenEmpty(t *testing.T) { + newSuite().RevertAllWhenEmpty(t) +} + +func TestMigrate_RevertAll(t *testing.T) { + newSuite().RevertAll(t) +} + +func TestMigrate_Redo(t *testing.T) { + newSuite().Redo(t) +} + +func TestMigrate_Drop(t *testing.T) { + // t.Skip() + newSuite().Drop(t) +} + +func newSuite() *tests.MigratorSuite { + m := NewMigrator(conn, Config{}) + suite := tests.NewMigratorSuite(m) + suite.ApplyTmpl = "CREATE TABLE %[1]s_%[2]d (id INT);" + suite.RevertTmpl = "DROP TABLE %[1]s_%[2]d;" + suite.CleanMigTmpl = "DROP TABLE IF EXISTS %[1]s_%[2]d;" + suite.CleanTest = "DELETE FROM _dbump_log;" + return suite +} + +func failIfErr(tb testing.TB, err error) { + tb.Helper() + if err != nil { + tb.Fatal(err) + } +} + +func mustEqual(tb testing.TB, got, want interface{}) { + tb.Helper() + if !reflect.DeepEqual(got, want) { + tb.Fatalf("\nhave %+v\nwant %+v", got, want) + } +} + +func envOrDef(env, def string) string { + if val := os.Getenv(env); val != "" { + return val + } + return def +}