Skip to content

Commit

Permalink
feat: support postgres (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
ginokent authored Jan 6, 2024
2 parents 68ccc70 + bb009ab commit a236137
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 29 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ options:
- `diff` subcommand
- dialect
- [ ] Support `mysql`
- [ ] Support `postgres`
- [x] Support `postgres`
- [x] Support `cockroachdb`
- [ ] Support `spanner`
- [ ] Support `sqlite3`
- `apply` subcommand
- dialect
- [ ] Support `mysql`
- [ ] Support `postgres`
- [x] Support `postgres`
- [x] Support `cockroachdb`
- [ ] Support `spanner`
- [ ] Support `sqlite3`
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module github.com/kunitsucom/ddlctl
go 1.21.5

require (
github.com/kunitsucom/util.go v0.0.60-rc.4
github.com/kunitsucom/util.go v0.0.60-rc.6
github.com/lib/pq v1.10.9
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
github.com/kunitsucom/util.go v0.0.60-rc.4 h1:TGqi1YeOgfh4dScSfI6C14HDhoX2qCkbWEak5ZIV1Q4=
github.com/kunitsucom/util.go v0.0.60-rc.4/go.mod h1:bYFf2JvRqVF1brBtpdt3xkkTGJBxmYBxZlItrc/lf7Y=
github.com/kunitsucom/util.go v0.0.60-rc.6 h1:aiyBwQzzqxzQzWz5fZCMoj1bOWPA7iAVgju85/zXmA8=
github.com/kunitsucom/util.go v0.0.60-rc.6/go.mod h1:bYFf2JvRqVF1brBtpdt3xkkTGJBxmYBxZlItrc/lf7Y=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
58 changes: 58 additions & 0 deletions internal/integrationtest/integrationtest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package integrationtest_test

import (
"context"
"os"
"testing"

cliz "github.com/kunitsucom/util.go/exp/cli"
testingz "github.com/kunitsucom/util.go/testing"
"github.com/kunitsucom/util.go/testing/assert"
"github.com/kunitsucom/util.go/testing/require"

"github.com/kunitsucom/ddlctl/internal/ddlctl/fixture"
"github.com/kunitsucom/ddlctl/pkg/ddlctl"
)

//nolint:paralleltest
func Test_ddlctl_diff(t *testing.T) {
t.Run("success,go,postgres", func(t *testing.T) {
cmd := fixture.Cmd()
args, err := cmd.Parse([]string{
"--lang=go",
"--dialect=postgres",
"postgres_before.sql",
"postgres_after.sql",
})
require.NoError(t, err)
ctx := cliz.WithContext(context.Background(), cmd)

backup := os.Stdout
t.Cleanup(func() { os.Stdout = backup })

w, closeFunc, err := testingz.NewFileWriter(t)
require.NoError(t, err)

os.Stdout = w
{
err := ddlctl.Diff(ctx, args)
require.NoError(t, err)
}
result := closeFunc()

const expected = `-- -
-- +description TEXT NOT NULL
ALTER TABLE public.test_groups ADD COLUMN description TEXT NOT NULL;
-- -name TEXT NOT NULL
-- +
ALTER TABLE public.test_users DROP COLUMN name;
-- -
-- +username TEXT NOT NULL
ALTER TABLE public.test_users ADD COLUMN username TEXT NOT NULL;
`

actual := result.String()

assert.Equal(t, expected, actual)
})
}
13 changes: 13 additions & 0 deletions internal/integrationtest/postgres_after.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
CREATE TABLE public.test_groups (
group_id UUID NOT NULL,
group_name TEXT NOT NULL,
description TEXT NOT NULL,
PRIMARY KEY (group_id)
);
CREATE TABLE public.test_users (
user_id UUID NOT NULL,
username TEXT NOT NULL,
group_id UUID NOT NULL,
PRIMARY KEY (user_id)
);
CREATE INDEX test_users_idx_on_group_id ON public.test_users (group_id);
12 changes: 12 additions & 0 deletions internal/integrationtest/postgres_before.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
CREATE TABLE IF NOT EXISTS public.test_groups (
group_id UUID NOT NULL,
group_name TEXT NOT NULL,
PRIMARY KEY (group_id)
);
CREATE TABLE IF NOT EXISTS public.test_users (
user_id UUID NOT NULL,
name TEXT NOT NULL,
group_id UUID NOT NULL,
PRIMARY KEY (user_id)
);
CREATE INDEX IF NOT EXISTS test_users_idx_on_group_id ON public.test_users (group_id);
13 changes: 9 additions & 4 deletions pkg/ddlctl/ddlctl_apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

//nolint:cyclop,funlen
func Apply(ctx context.Context, args []string) error {
func Apply(ctx context.Context, args []string) (err error) {
if _, err := config.Load(ctx); err != nil {
return errorz.Errorf("config.Load: %w", err)
}
Expand All @@ -37,13 +37,14 @@ func Apply(ctx context.Context, args []string) error {
if err := diff(buf, left, right); err != nil {
return errorz.Errorf("diff: %w", err)
}
q := buf.String()

msg := `
ddlctl will exec the following DDL queries:
-- 8< --
` + buf.String() + `
` + q + `
-- >8 --
Expand Down Expand Up @@ -73,9 +74,13 @@ Enter a value: `
if err != nil {
return errorz.Errorf("sqlz.OpenContext: %w", err)
}
defer db.Close()
defer func() {
if cerr := db.Close(); err == nil && cerr != nil {
err = errorz.Errorf("db.Close: %w", cerr)
}
}()

if _, err := db.ExecContext(ctx, buf.String()); err != nil {
if _, err := db.ExecContext(ctx, q); err != nil {
return errorz.Errorf("db.ExecContext: %w", err)
}

Expand Down
76 changes: 56 additions & 20 deletions pkg/ddlctl/ddlctl_diff.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ddlctl

import (
"bytes"
"context"
"io"
"os"
Expand All @@ -12,15 +11,18 @@ import (
errorz "github.com/kunitsucom/util.go/errors"
crdbddl "github.com/kunitsucom/util.go/exp/database/sql/ddl/cockroachdb"
pgddl "github.com/kunitsucom/util.go/exp/database/sql/ddl/postgres"
crdbutil "github.com/kunitsucom/util.go/exp/database/sql/util/cockroachdb"
myutil "github.com/kunitsucom/util.go/exp/database/sql/util/mysql"
pgutil "github.com/kunitsucom/util.go/exp/database/sql/util/postgres"
osz "github.com/kunitsucom/util.go/os"

"github.com/kunitsucom/ddlctl/internal/config"
pgddlgen "github.com/kunitsucom/ddlctl/internal/ddlctl/ddl/dialect/postgres"
"github.com/kunitsucom/ddlctl/internal/logs"
apperr "github.com/kunitsucom/ddlctl/pkg/errors"
)

const (
_mysql = "mysql"
_postgres = "postgres"
_cockroachdb = "cockroachdb"
)
Expand All @@ -39,15 +41,10 @@ func Diff(ctx context.Context, args []string) error {
return errorz.Errorf("resolve: %w", err)
}

buf := bytes.NewBuffer(nil)
if err := diff(buf, left, right); err != nil {
if err := diff(os.Stdout, left, right); err != nil {
return errorz.Errorf("diff: %w", err)
}

if _, err := io.Copy(os.Stdout, buf); err != nil {
return errorz.Errorf("io.Copy: %w", err)
}

return nil
}

Expand Down Expand Up @@ -85,7 +82,7 @@ func resolve(ctx context.Context, dialect, left, right string) (srcDDL string, d

ddl, err := dumpCreateStmts(ctx, dialect, right)
if err != nil {
return "", "", errorz.Errorf("sqlz.OpenContext: %w", err)
return "", "", errorz.Errorf("dumpCreateStmts: %w", err)
}
dstDDL = ddl
case osz.IsFile(right): // NOTE: expect SQL file
Expand All @@ -105,25 +102,59 @@ func resolve(ctx context.Context, dialect, left, right string) (srcDDL string, d
return srcDDL, dstDDL, nil
}

//nolint:cyclop
func dumpCreateStmts(ctx context.Context, dialect string, dsn string) (ddl string, err error) {
switch dialect {
case _cockroachdb:
case _mysql:
db, err := sqlz.OpenContext(ctx, _mysql, dsn)
if err != nil {
return "", errorz.Errorf("sqlz.OpenContext: %w", err)
}
defer func() {
if cerr := db.Close(); err == nil && cerr != nil {
err = errorz.Errorf("db.Close: %w", cerr)
}
}()

ddl, err := myutil.ShowCreateAllTables(ctx, db)
if err != nil {
return "", errorz.Errorf("pgutil.ShowCreateAllTables: %w", err)
}

return ddl, nil
case _postgres:
db, err := sqlz.OpenContext(ctx, _postgres, dsn)
if err != nil {
return "", errorz.Errorf("sqlz.OpenContext: %w", err)
}
defer db.Close()
defer func() {
if cerr := db.Close(); err == nil && cerr != nil {
err = errorz.Errorf("db.Close: %w", cerr)
}
}()

type CreateTableStatement struct {
CreateStatement string `db:"create_statement"`
ddl, err := pgutil.ShowCreateAllTables(ctx, db)
if err != nil {
return "", errorz.Errorf("pgutil.ShowCreateAllTables: %w", err)
}
v := new([]*CreateTableStatement)
if err := sqlz.NewDB(db).QueryContext(ctx, v, "SHOW CREATE ALL TABLES;"); err != nil {
return "", errorz.Errorf("sqlz.NewDB.QueryContext: %w", err)

return ddl, nil
case _cockroachdb:
db, err := sqlz.OpenContext(ctx, _postgres, dsn)
if err != nil {
return "", errorz.Errorf("sqlz.OpenContext: %w", err)
}
for _, stmt := range *v {
ddl += stmt.CreateStatement
defer func() {
if cerr := db.Close(); err == nil && cerr != nil {
err = errorz.Errorf("db.Close: %w", cerr)
}
}()

ddl, err := crdbutil.ShowCreateAllTables(ctx, db)
if err != nil {
return "", errorz.Errorf("crdbutil.ShowCreateAllTables: %w", err)
}

return ddl, nil
default:
return "", errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported)
Expand All @@ -150,7 +181,10 @@ func diff(out io.Writer, src, dst string) error {
logs.Debug.Printf("dst: %q", dst)

switch dialect := config.Dialect(); dialect {
case pgddlgen.Dialect:
case _mysql:
// TODO: implement
return errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported)
case _postgres:
leftDDL, err := pgddl.NewParser(pgddl.NewLexer(src)).Parse()
if err != nil {
return errorz.Errorf("pgddl.NewParser: %w", err)
Expand All @@ -165,7 +199,9 @@ func diff(out io.Writer, src, dst string) error {
return errorz.Errorf("pgddl.Diff: %w", err)
}

os.Stdout.WriteString(result.String())
if _, err := io.WriteString(out, result.String()); err != nil {
return errorz.Errorf("io.WriteString: %w", err)
}

return nil
case _cockroachdb:
Expand Down

0 comments on commit a236137

Please sign in to comment.