Skip to content

Commit

Permalink
Improve dir schema source errors (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
bplunkett-stripe committed Sep 9, 2024
1 parent 6937348 commit 5fe8259
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 45 deletions.
41 changes: 4 additions & 37 deletions cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"database/sql"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -223,19 +221,12 @@ func parsePlanConfig(p planFlags) (planConfig, error) {

func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) {
if len(p.schemaDirs) > 0 {
var ddl []string
// Ordering of execution of schema SQL can be guaranteed by:
// - Splitting across multiple directories and using multiple schema dir flags
// - Relying on lexical order of SQL files
for _, schemaDir := range p.schemaDirs {
stmts, err := getDDLFromPath(schemaDir)
return func() (diff.SchemaSource, io.Closer, error) {
schemaSource, err := diff.DirSchemaSource(p.schemaDirs)
if err != nil {
return nil, fmt.Errorf("getting DDL from path %q: %w", schemaDir, err)
return nil, nil, err
}
ddl = append(ddl, stmts...)
}
return func() (diff.SchemaSource, io.Closer, error) {
return diff.DDLSchemaSource(ddl), nil, nil
return schemaSource, nil, nil
}, nil
}

Expand Down Expand Up @@ -434,30 +425,6 @@ func applyPlanModifiers(
return plan, nil
}

// getDDLFromPath reads all .sql files under the given path (including sub-directories) and returns the DDL
// in lexical order.
func getDDLFromPath(path string) ([]string, error) {
var ddl []string
if err := filepath.Walk(path, func(path string, entry os.FileInfo, err error) error {
if err != nil {
return fmt.Errorf("walking path %q: %w", path, err)
}
if strings.ToLower(filepath.Ext(entry.Name())) != ".sql" {
return nil
}

if stmts, err := os.ReadFile(path); err != nil {
return fmt.Errorf("reading file %q: %w", entry.Name(), err)
} else {
ddl = append(ddl, string(stmts))
}
return nil
}); err != nil {
return nil, err
}
return ddl, nil
}

func planToPrettyS(plan diff.Plan) string {
sb := strings.Builder{}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/stripe/pg-schema-diff/pkg/tempdb"
)

func databaseSchemaSourcePlanFactory(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) {
func databaseSchemaSourcePlan(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) {
newSchemaDb, err := tempDbFactory.Create(ctx)
if err != nil {
return diff.Plan{}, fmt.Errorf("creating temp database: %w", err)
Expand Down Expand Up @@ -38,8 +38,29 @@ func databaseSchemaSourcePlanFactory(ctx context.Context, connPool sqldb.Queryab
return diff.Generate(ctx, connPool, diff.DBSchemaSource(newSchemaDb.ConnPool), opts...)
}

func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory {
return func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) {
// Clone the opts so we don't modify the original.
opts = append([]diff.PlanOpt(nil), opts...)
opts = append(opts, diff.WithTempDbFactory(tempDbFactory))

if len(newSchemaDDL) != 0 {
panic("newSchemaDDL should be empty for dir schema sources")
}

schemaSource, err := diff.DirSchemaSource(schemaDirs)
if err != nil {
return diff.Plan{}, fmt.Errorf("creating schema source: %w", err)
}

return diff.Generate(ctx, connPool, schemaSource, opts...)
}
}

var databaseSchemaSourceTestCases = []acceptanceTestCase{
{
planFactory: databaseSchemaSourcePlan,

name: "Drop partitioned table, Add partitioned table with local keys",
oldSchemaDDL: []string{
`
Expand Down Expand Up @@ -119,8 +140,48 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeDeletesData,
},
},
{
planFactory: dirSchemaSourcePlanFactory([]string{"testdata/dirsrc_happy_path/schema_0", "testdata/dirsrc_happy_path/schema_1"}),

name: "Dir src - happy path",
oldSchemaDDL: []string{
`
CREATE TABLE foobar( );
`,
},

expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeIndexBuild,
},
expectedDBSchemaDDL: []string{
`
CREATE TYPE color AS ENUM ('red', 'green', 'blue');
CREATE TABLE foobar(
color color,
id varchar(255) PRIMARY KEY
);
CREATE TABLE foobar_fk(
id TEXT REFERENCES foobar(id)
);
CREATE TABLE fizzbuzz(
id TEXT,
primary_color color,
other_color color
); `,
},
},
{
planFactory: dirSchemaSourcePlanFactory([]string{"testdata/dirsrc_invalid_sql/schema_0"}),

name: "Dir src - invalid sql",
oldSchemaDDL: []string{
`
CREATE TABLE foobar( );
`,
},

planFactory: databaseSchemaSourcePlanFactory,
expectedPlanErrorContains: "testdata/dirsrc_invalid_sql/schema_0/1.sql",
},
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TYPE color AS ENUM ('red', 'green', 'blue');
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE foobar (
id varchar(255) PRIMARY KEY,
color color
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE TABLE foobar_fk (
id TEXT REFERENCES foobar (id)
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE fizzbuzz (
id TEXT,
primary_color COLOR,
other_color COLOR
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE TABLE foobar (
id SERIAL PRIMARY KEY
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE TABLE fizzbuzz (
id TEXT REFERENCES non_existent_table (id)
);
85 changes: 79 additions & 6 deletions pkg/diff/schema_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package diff
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/stripe/pg-schema-diff/internal/schema"
"github.com/stripe/pg-schema-diff/pkg/log"
Expand All @@ -20,13 +23,79 @@ type SchemaSource interface {
GetSchema(ctx context.Context, deps schemaSourcePlanDeps) (schema.Schema, error)
}

type ddlSchemaSource struct {
ddl []string
type (
ddlStatement struct {
// stmt is the DDL statement to run.
stmt string
// file is an optional field that can be used to store the file name from which the DDL was read.
file string
}

ddlSchemaSource struct {
ddl []ddlStatement
}
)

// DirSchemaSource returns a SchemaSource that returns a schema based on the provided directories. You must provide a tempDBFactory
// via the WithTempDbFactory option.
func DirSchemaSource(dirs []string) (SchemaSource, error) {
var ddl []ddlStatement
for _, dir := range dirs {
stmts, err := getDDLFromPath(dir)
if err != nil {
return &ddlSchemaSource{}, err
}
ddl = append(ddl, stmts...)

}
return &ddlSchemaSource{
ddl: ddl,
}, nil
}

// getDDLFromPath reads all .sql files under the given path (including sub-directories) and returns the DDL
// in lexical order.
func getDDLFromPath(path string) ([]ddlStatement, error) {
var ddl []ddlStatement
if err := filepath.Walk(path, func(path string, entry os.FileInfo, err error) error {
if err != nil {
return fmt.Errorf("walking path %q: %w", path, err)
}
if strings.ToLower(filepath.Ext(entry.Name())) != ".sql" {
return nil
}

fileContents, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("reading file %q: %w", entry.Name(), err)
}

// In the future, it would make sense to split the file contents into individual DDL statements; however,
// that would require fully parsing the SQL. Naively splitting on `;` would not work because `;` can be
// used in comments, strings, and escaped identifiers.
ddl = append(ddl, ddlStatement{
stmt: string(fileContents),
file: path,
})
return nil
}); err != nil {
return nil, err
}
return ddl, nil
}

// DDLSchemaSource returns a SchemaSource that returns a schema based on the provided DDL. You must provide a tempDBFactory
// via the WithTempDbFactory option.
func DDLSchemaSource(ddl []string) SchemaSource {
func DDLSchemaSource(stmts []string) SchemaSource {
var ddl []ddlStatement
for _, stmt := range stmts {
ddl = append(ddl, ddlStatement{
stmt: stmt,
// There is no file name associated with the DDL statement.
file: ""},
)
}

return &ddlSchemaSource{ddl: ddl}
}

Expand All @@ -45,9 +114,13 @@ func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDe
}
}(tempDb.ContextualCloser)

for _, stmt := range s.ddl {
if _, err := tempDb.ConnPool.ExecContext(ctx, stmt); err != nil {
return schema.Schema{}, fmt.Errorf("running DDL: %w", err)
for _, ddlStmt := range s.ddl {
if _, err := tempDb.ConnPool.ExecContext(ctx, ddlStmt.stmt); err != nil {
debugInfo := ""
if ddlStmt.file != "" {
debugInfo = fmt.Sprintf(" (from %s)", ddlStmt.file)
}
return schema.Schema{}, fmt.Errorf("running DDL%s: %w", debugInfo, err)
}
}

Expand Down

0 comments on commit 5fe8259

Please sign in to comment.