Skip to content

Commit

Permalink
Add SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jun 19, 2024
1 parent 64bd95d commit 835a812
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 0 deletions.
10 changes: 10 additions & 0 deletions dbump_sqlite/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module github.com/cristalhq/dbump/dbump_sqlite

go 1.17

require (
github.com/cristalhq/dbump v0.14.0
github.com/mattn/go-sqlite3 v1.14.22
)

replace github.com/cristalhq/dbump => ../
2 changes: 2 additions & 0 deletions dbump_sqlite/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
138 changes: 138 additions & 0 deletions dbump_sqlite/sqlite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package dbump_sqlite

import (
"context"
"database/sql"
"errors"
"fmt"
"hash/fnv"

"github.com/cristalhq/dbump"
)

var _ dbump.Migrator = &Migrator{}

// Migrator to migrate Postgres.
type Migrator struct {
conn *sql.DB
cfg Config
}

// Config for the migrator.
type Config struct {
// Schema for the dbump version table. Default is empty which means "public" schema.
Schema string
// Table for the dbump version table. Default is empty which means "_dbump_log" table.
Table string

// [schema.]table
tableName string
// to prevent multiple migrations running at the same time
lockNum int64

_ struct{} // enforce explicit field names.
}

// NewMigrator instantiates new Migrator.
func NewMigrator(conn *sql.DB, cfg Config) *Migrator {
if cfg.Schema != "" {
cfg.Schema += "."
}
if cfg.Table == "" {
cfg.Table = "_dbump_log"
}

cfg.tableName = cfg.Schema + cfg.Table
cfg.lockNum = hashTableName(cfg.tableName)

return &Migrator{
conn: conn,
cfg: cfg,
}
}

// Init is a method from dbump.Migrator interface.
func (s *Migrator) Init(ctx context.Context) error {
query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
version INTEGER NOT NULL,
created_at TEXT NOT NULL
);`, s.cfg.tableName)

_, err := s.conn.ExecContext(ctx, query)
return err
}

// Drop is a method from dbump.Migrator interface.
func (s *Migrator) Drop(ctx context.Context) error {
query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, s.cfg.tableName)

// TODO: probably should ignore error for this query
// if s.cfg.Schema != "" {
// query = fmt.Sprintf(`DROP SCHEMA IF EXISTS %s RESTRICT;`, s.cfg.Schema)
// }
_, err := s.conn.ExecContext(ctx, query)
return err
}

// LockDB is a method from dbump.Migrator interface.
func (s *Migrator) LockDB(ctx context.Context) error {
return nil
}

// UnlockDB is a method from dbump.Migrator interface.
func (s *Migrator) UnlockDB(ctx context.Context) error {
return nil
}

// Version is a method from dbump.Migrator interface.
func (s *Migrator) Version(ctx context.Context) (version int, err error) {
query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", s.cfg.tableName)
row := s.conn.QueryRowContext(ctx, query)
err = row.Scan(&version)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return version, err
}

// DoStep is a method from dbump.Migrator interface.
func (s *Migrator) DoStep(ctx context.Context, step dbump.Step) error {
if step.DisableTx {
if _, err := s.conn.ExecContext(ctx, step.Query); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName)
_, err := s.conn.ExecContext(ctx, query, step.Version)
return err
}
return s.inTx(ctx, step)
}

func (s *Migrator) inTx(ctx context.Context, step dbump.Step) error {
tx, err := s.conn.BeginTx(ctx, nil)
if err != nil {
return err
}

query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, STRFTIME('%%Y-%%m-%%d %%H:%%M:%%f', 'NOW'));", s.cfg.tableName)
_, err = s.conn.ExecContext(ctx, query, step.Version)
// fmt.Printf("args: %s %d\n", query, step.Version)
if err != nil {
_ = tx.Rollback()
return err
}

_, err = tx.ExecContext(ctx, step.Query)
if err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}

func hashTableName(s string) int64 {
h := fnv.New64()
h.Write([]byte(s))
return int64(h.Sum64())
}
169 changes: 169 additions & 0 deletions dbump_sqlite/sqlite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package dbump_sqlite

import (
"context"
"database/sql"
"os"
"reflect"
"testing"

"github.com/cristalhq/dbump"
"github.com/cristalhq/dbump/tests"

_ "github.com/mattn/go-sqlite3"
)

var conn *sql.DB

func init() {
path := os.Getenv("DBUMP_SQLITE_PATH")
if path == "" {
path = "./db.sqlitedb" // + time.Now().String()
}

var err error
conn, err = sql.Open("sqlite3", path)
if err != nil {
panic(err)
}
}

func TestSQLite_Simple(t *testing.T) {
m := NewMigrator(conn, Config{})
l := dbump.NewSliceLoader([]*dbump.Migration{
{
ID: 1,
Apply: "SELECT 1;",
Revert: "SELECT 1;",
},
{
ID: 2,
Apply: "SELECT 1;",
Revert: "SELECT 1;",
},
{
ID: 3,
Apply: "SELECT 1;",
Revert: "SELECT 1;",
},
})

errRun := dbump.Run(context.Background(), dbump.Config{
Migrator: m,
Loader: l,
Mode: dbump.ModeApplyAll,
})
failIfErr(t, errRun)
}

func TestNonDefaultSchemaTable(t *testing.T) {
testCases := []struct {
name string
schema string
table string
wantTableName string
wantLockNum int64
}{
{
name: "all empty",
schema: "",
table: "",
wantTableName: "_dbump_log",
wantLockNum: -3987518601082986461,
},
{
name: "schema set",
schema: "test_schema",
table: "",
wantTableName: "test_schema._dbump_log",
wantLockNum: 1417388815471108263,
},
{
name: "table set",
schema: "",
table: "test_table",
wantTableName: "test_table",
wantLockNum: 8712390964734167792,
},
{
name: "schema and table set",
schema: "test_schema",
table: "test_table",
wantTableName: "test_schema.test_table",
wantLockNum: 4631047095544292572,
},
}

for _, tc := range testCases {
m := NewMigrator(conn, Config{
Schema: tc.schema,
Table: tc.table,
})
mustEqual(t, m.cfg.tableName, tc.wantTableName)
mustEqual(t, m.cfg.lockNum, tc.wantLockNum)
}
}

func TestMigrate_ApplyAll(t *testing.T) {
newSuite().ApplyAll(t)
}

func TestMigrate_ApplyOne(t *testing.T) {
newSuite().ApplyOne(t)
}

func TestMigrate_ApplyAllWhenFull(t *testing.T) {
newSuite().ApplyAllWhenFull(t)
}

func TestMigrate_RevertOne(t *testing.T) {
newSuite().RevertOne(t)
}

func TestMigrate_RevertAllWhenEmpty(t *testing.T) {
newSuite().RevertAllWhenEmpty(t)
}

func TestMigrate_RevertAll(t *testing.T) {
newSuite().RevertAll(t)
}

func TestMigrate_Redo(t *testing.T) {
newSuite().Redo(t)
}

func TestMigrate_Drop(t *testing.T) {
// t.Skip()
newSuite().Drop(t)
}

func newSuite() *tests.MigratorSuite {
m := NewMigrator(conn, Config{})
suite := tests.NewMigratorSuite(m)
suite.ApplyTmpl = "CREATE TABLE %[1]s_%[2]d (id INT);"
suite.RevertTmpl = "DROP TABLE %[1]s_%[2]d;"
suite.CleanMigTmpl = "DROP TABLE IF EXISTS %[1]s_%[2]d;"
suite.CleanTest = "DELETE FROM _dbump_log;"
return suite
}

func failIfErr(tb testing.TB, err error) {
tb.Helper()
if err != nil {
tb.Fatal(err)
}
}

func mustEqual(tb testing.TB, got, want interface{}) {
tb.Helper()
if !reflect.DeepEqual(got, want) {
tb.Fatalf("\nhave %+v\nwant %+v", got, want)
}
}

func envOrDef(env, def string) string {
if val := os.Getenv(env); val != "" {
return val
}
return def
}

0 comments on commit 835a812

Please sign in to comment.