Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dir schema source errors #172

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 4 additions & 37 deletions cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"database/sql"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -223,19 +221,12 @@ func parsePlanConfig(p planFlags) (planConfig, error) {

func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) {
if len(p.schemaDirs) > 0 {
var ddl []string
// Ordering of execution of schema SQL can be guaranteed by:
// - Splitting across multiple directories and using multiple schema dir flags
// - Relying on lexical order of SQL files
for _, schemaDir := range p.schemaDirs {
stmts, err := getDDLFromPath(schemaDir)
return func() (diff.SchemaSource, io.Closer, error) {
schemaSource, err := diff.DirSchemaSource(p.schemaDirs)
if err != nil {
return nil, fmt.Errorf("getting DDL from path %q: %w", schemaDir, err)
return nil, nil, err
}
ddl = append(ddl, stmts...)
}
return func() (diff.SchemaSource, io.Closer, error) {
return diff.DDLSchemaSource(ddl), nil, nil
return schemaSource, nil, nil
}, nil
}

Expand Down Expand Up @@ -434,30 +425,6 @@ func applyPlanModifiers(
return plan, nil
}

// getDDLFromPath reads all .sql files under the given path (including sub-directories) and returns the DDL
// in lexical order.
func getDDLFromPath(path string) ([]string, error) {
var ddl []string
if err := filepath.Walk(path, func(path string, entry os.FileInfo, err error) error {
if err != nil {
return fmt.Errorf("walking path %q: %w", path, err)
}
if strings.ToLower(filepath.Ext(entry.Name())) != ".sql" {
return nil
}

if stmts, err := os.ReadFile(path); err != nil {
return fmt.Errorf("reading file %q: %w", entry.Name(), err)
} else {
ddl = append(ddl, string(stmts))
}
return nil
}); err != nil {
return nil, err
}
return ddl, nil
}

func planToPrettyS(plan diff.Plan) string {
sb := strings.Builder{}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/stripe/pg-schema-diff/pkg/tempdb"
)

func databaseSchemaSourcePlanFactory(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) {
func databaseSchemaSourcePlan(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) {
newSchemaDb, err := tempDbFactory.Create(ctx)
if err != nil {
return diff.Plan{}, fmt.Errorf("creating temp database: %w", err)
Expand Down Expand Up @@ -38,8 +38,29 @@ func databaseSchemaSourcePlanFactory(ctx context.Context, connPool sqldb.Queryab
return diff.Generate(ctx, connPool, diff.DBSchemaSource(newSchemaDb.ConnPool), opts...)
}

func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory {
return func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (_ diff.Plan, retErr error) {
// Clone the opts so we don't modify the original.
opts = append([]diff.PlanOpt(nil), opts...)
opts = append(opts, diff.WithTempDbFactory(tempDbFactory))

if len(newSchemaDDL) != 0 {
panic("newSchemaDDL should be empty for dir schema sources")
}

schemaSource, err := diff.DirSchemaSource(schemaDirs)
if err != nil {
return diff.Plan{}, fmt.Errorf("creating schema source: %w", err)
}

return diff.Generate(ctx, connPool, schemaSource, opts...)
}
}

var databaseSchemaSourceTestCases = []acceptanceTestCase{
{
planFactory: databaseSchemaSourcePlan,

name: "Drop partitioned table, Add partitioned table with local keys",
oldSchemaDDL: []string{
`
Expand Down Expand Up @@ -119,8 +140,48 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeDeletesData,
},
},
{
planFactory: dirSchemaSourcePlanFactory([]string{"testdata/dirsrc_happy_path/schema_0", "testdata/dirsrc_happy_path/schema_1"}),

name: "Dir src - happy path",
oldSchemaDDL: []string{
`
CREATE TABLE foobar( );
`,
},

expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeIndexBuild,
},
expectedDBSchemaDDL: []string{
`
CREATE TYPE color AS ENUM ('red', 'green', 'blue');
CREATE TABLE foobar(
color color,
id varchar(255) PRIMARY KEY
);
CREATE TABLE foobar_fk(
id TEXT REFERENCES foobar(id)
);
CREATE TABLE fizzbuzz(
id TEXT,
primary_color color,
other_color color
); `,
},
},
{
planFactory: dirSchemaSourcePlanFactory([]string{"testdata/dirsrc_invalid_sql/schema_0"}),

name: "Dir src - invalid sql",
oldSchemaDDL: []string{
`
CREATE TABLE foobar( );
`,
},

planFactory: databaseSchemaSourcePlanFactory,
expectedPlanErrorContains: "testdata/dirsrc_invalid_sql/schema_0/1.sql",
},
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TYPE color AS ENUM ('red', 'green', 'blue');
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE foobar (
id varchar(255) PRIMARY KEY,
color color
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE TABLE foobar_fk (
id TEXT REFERENCES foobar (id)
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE fizzbuzz (
id TEXT,
primary_color COLOR,
other_color COLOR
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE TABLE foobar (
id SERIAL PRIMARY KEY
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE TABLE fizzbuzz (
id TEXT REFERENCES non_existent_table (id)
);
85 changes: 79 additions & 6 deletions pkg/diff/schema_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package diff
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/stripe/pg-schema-diff/internal/schema"
"github.com/stripe/pg-schema-diff/pkg/log"
Expand All @@ -20,13 +23,79 @@ type SchemaSource interface {
GetSchema(ctx context.Context, deps schemaSourcePlanDeps) (schema.Schema, error)
}

type ddlSchemaSource struct {
ddl []string
type (
ddlStatement struct {
// stmt is the DDL statement to run.
stmt string
// file is an optional field that can be used to store the file name from which the DDL was read.
file string
}

ddlSchemaSource struct {
ddl []ddlStatement
}
)

// DirSchemaSource returns a SchemaSource that returns a schema based on the provided directories. You must provide a tempDBFactory
// via the WithTempDbFactory option.
func DirSchemaSource(dirs []string) (SchemaSource, error) {
var ddl []ddlStatement
for _, dir := range dirs {
stmts, err := getDDLFromPath(dir)
if err != nil {
return &ddlSchemaSource{}, err
}
ddl = append(ddl, stmts...)

}
return &ddlSchemaSource{
ddl: ddl,
}, nil
}

// getDDLFromPath reads all .sql files under the given path (including sub-directories) and returns the DDL
// in lexical order.
func getDDLFromPath(path string) ([]ddlStatement, error) {
var ddl []ddlStatement
if err := filepath.Walk(path, func(path string, entry os.FileInfo, err error) error {
if err != nil {
return fmt.Errorf("walking path %q: %w", path, err)
}
if strings.ToLower(filepath.Ext(entry.Name())) != ".sql" {
return nil
}

fileContents, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("reading file %q: %w", entry.Name(), err)
}

// In the future, it would make sense to split the file contents into individual DDL statements; however,
// that would require fully parsing the SQL. Naively splitting on `;` would not work because `;` can be
// used in comments, strings, and escaped identifiers.
ddl = append(ddl, ddlStatement{
stmt: string(fileContents),
file: path,
})
return nil
}); err != nil {
return nil, err
}
return ddl, nil
}

// DDLSchemaSource returns a SchemaSource that returns a schema based on the provided DDL. You must provide a tempDBFactory
// via the WithTempDbFactory option.
func DDLSchemaSource(ddl []string) SchemaSource {
func DDLSchemaSource(stmts []string) SchemaSource {
var ddl []ddlStatement
for _, stmt := range stmts {
ddl = append(ddl, ddlStatement{
stmt: stmt,
// There is no file name associated with the DDL statement.
file: ""},
)
}

return &ddlSchemaSource{ddl: ddl}
}

Expand All @@ -45,9 +114,13 @@ func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDe
}
}(tempDb.ContextualCloser)

for _, stmt := range s.ddl {
if _, err := tempDb.ConnPool.ExecContext(ctx, stmt); err != nil {
return schema.Schema{}, fmt.Errorf("running DDL: %w", err)
for _, ddlStmt := range s.ddl {
if _, err := tempDb.ConnPool.ExecContext(ctx, ddlStmt.stmt); err != nil {
debugInfo := ""
if ddlStmt.file != "" {
debugInfo = fmt.Sprintf(" (from %s)", ddlStmt.file)
}
return schema.Schema{}, fmt.Errorf("running DDL%s: %w", debugInfo, err)
}
}

Expand Down
Loading