diff --git a/cmd/pg-schema-diff/datastructs.go b/cmd/pg-schema-diff/datastructs.go new file mode 100644 index 0000000..1a5de84 --- /dev/null +++ b/cmd/pg-schema-diff/datastructs.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "sort" +) + +func mustGetAndDeleteKey(m map[string]string, key string) (string, error) { + val, ok := m[key] + if !ok { + return "", fmt.Errorf("could not find key %q", key) + } + delete(m, key) + return val, nil +} + +func keys(m map[string]string) []string { + var vals []string + for k := range m { + vals = append(vals, k) + } + sort.Strings(vals) + return vals +} diff --git a/cmd/pg-schema-diff/datastructs_test.go b/cmd/pg-schema-diff/datastructs_test.go new file mode 100644 index 0000000..26479d9 --- /dev/null +++ b/cmd/pg-schema-diff/datastructs_test.go @@ -0,0 +1,42 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeys(t *testing.T) { + for _, tt := range []struct { + name string + m map[string]string + + want []string + }{ + { + name: "nil map", + + want: nil, + }, + { + name: "empty map", + + want: nil, + }, + { + name: "filled map", + m: map[string]string{ + // Use an arbitrary order + "key2": "value2", + "key3": "value3", + "key1": "value1", + }, + + want: []string{"key1", "key2", "key3"}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, keys(tt.m)) + }) + } +} diff --git a/cmd/pg-schema-diff/flags.go b/cmd/pg-schema-diff/flags.go index 128acfc..7308352 100644 --- a/cmd/pg-schema-diff/flags.go +++ b/cmd/pg-schema-diff/flags.go @@ -1,6 +1,10 @@ package main import ( + "fmt" + "strings" + + "github.com/go-logfmt/logfmt" "github.com/jackc/pgx/v4" "github.com/spf13/cobra" "github.com/stripe/pg-schema-diff/pkg/log" @@ -27,3 +31,24 @@ func parseConnConfig(c connFlags, logger log.Logger) (*pgx.ConnConfig, error) { return pgx.ParseConfig(c.dsn) } + +// LogFmtToMap parses all LogFmt key/value pairs from the provided string into a +// map. +// +// All records are scanned. If a duplicate key is found, an error is returned. +func LogFmtToMap(logFmt string) (map[string]string, error) { + logMap := make(map[string]string) + decoder := logfmt.NewDecoder(strings.NewReader(logFmt)) + for decoder.ScanRecord() { + for decoder.ScanKeyval() { + if _, ok := logMap[string(decoder.Key())]; ok { + return nil, fmt.Errorf("duplicate key %q in logfmt", string(decoder.Key())) + } + logMap[string(decoder.Key())] = string(decoder.Value()) + } + } + if decoder.Err() != nil { + return nil, decoder.Err() + } + return logMap, nil +} diff --git a/cmd/pg-schema-diff/flags_test.go b/cmd/pg-schema-diff/flags_test.go new file mode 100644 index 0000000..351e5c7 --- /dev/null +++ b/cmd/pg-schema-diff/flags_test.go @@ -0,0 +1,72 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLogFmtToMap(t *testing.T) { + type args struct { + logFmt string + } + tests := []struct { + name string + args args + want map[string]string + wantErr bool + }{ + { + name: "empty string", + args: args{ + logFmt: "", + }, + want: map[string]string{}, + wantErr: false, + }, + { + name: "single key value pair", + args: args{ + logFmt: "key=value", + }, + want: map[string]string{"key": "value"}, + wantErr: false, + }, + { + name: "multiple key value pairs", + args: args{ + logFmt: "key1=value1 key2=value2", + }, + want: map[string]string{"key1": "value1", "key2": "value2"}, + wantErr: false, + }, + { + name: "duplicate key", + args: args{ + logFmt: "key=value1 key=value2", + }, + want: nil, + wantErr: true, + }, + { + name: "multiple records", + args: args{ + logFmt: "key1=value1 key2=value2\nkey3=value3", + }, + want: map[string]string{"key1": "value1", "key2": "value2", "key3": "value3"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := LogFmtToMap(tt.args.logFmt) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go index 939c3fa..2afd44d 100644 --- a/cmd/pg-schema-diff/plan_cmd.go +++ b/cmd/pg-schema-diff/plan_cmd.go @@ -21,24 +21,14 @@ import ( const ( defaultMaxConnections = 5 -) -var ( - // Match arguments in the format "regex=duration" where duration is any duration valid in time.ParseDuration - // We'll let time.ParseDuration handle the complexity of parsing invalid duration, so the regex we're extracting is - // all characters greedily up to the rightmost "=" - statementTimeoutModifierRegex = regexp.MustCompile(`^(?P.+)=(?P.+)$`) - regexSTMRegexIndex = statementTimeoutModifierRegex.SubexpIndex("regex") - durationSTMRegexIndex = statementTimeoutModifierRegex.SubexpIndex("duration") - - // Match arguments in the format "index duration:statement" where duration is any duration valid in - // time.ParseDuration. In order to prevent matching on ":" in the duration, limit the character to just letters - // and numbers. To keep the regex simple, we won't bother matching on a more specific pattern for durations. - // time.ParseDuration can handle the complexity of parsing invalid durations - insertStatementRegex = regexp.MustCompile(`^(?P\d+) (?P[a-zA-Z0-9\.]+):(?P.+?);?$`) - indexInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("index") - durationInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("duration") - ddlInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("ddl") + patternTimeoutModifierKey = "pattern" + timeoutTimeoutModifierKey = "timeout" + + indexInsertStatementKey = "index" + statementInsertStatementKey = "statement" + statementTimeoutInsertStatementKey = "timeout" + lockTimeoutInsertStatementKey = "lock_timeout" ) func buildPlanCmd() *cobra.Command { @@ -100,15 +90,16 @@ type ( insertStatements []string } - timeoutModifiers struct { + timeoutModifier struct { regex *regexp.Regexp timeout time.Duration } insertStatement struct { - ddl string - index int - timeout time.Duration + ddl string + index int + timeout time.Duration + lockTimeout time.Duration } schemaSourceFactory func() (diff.SchemaSource, io.Closer, error) @@ -117,8 +108,8 @@ type ( schemaSourceFactory schemaSourceFactory opts []diff.PlanOpt - statementTimeoutModifiers []timeoutModifiers - lockTimeoutModifiers []timeoutModifiers + statementTimeoutModifiers []timeoutModifier + lockTimeoutModifiers []timeoutModifier insertStatements []insertStatement } ) @@ -132,9 +123,16 @@ func createPlanFlags(cmd *cobra.Command) *planFlags { timeoutModifierFlagVar(cmd, &flags.statementTimeoutModifiers, "statement", "t") timeoutModifierFlagVar(cmd, &flags.lockTimeoutModifiers, "lock", "l") - cmd.Flags().StringArrayVarP(&flags.insertStatements, "insert-statement", "s", nil, - "_: values. Will insert the statement at the index in the "+ - "generated plan with the specified timeout. This follows normal insert semantics. Example: -s '0 5s:SELECT 1''") + cmd.Flags().StringArrayVarP( + &flags.insertStatements, + "insert-statement", "s", nil, + fmt.Sprintf( + "'%s= %s=\"\" %s= %s=' values. Will insert the statement at the index in the "+ + "generated plan. This follows normal insert semantics. Example: -s '%s=1 %s=\"SELECT pg_sleep(5)\" %s=5s %s=1s'", + indexInsertStatementKey, statementInsertStatementKey, statementTimeoutInsertStatementKey, lockTimeoutInsertStatementKey, + indexInsertStatementKey, statementInsertStatementKey, statementTimeoutInsertStatementKey, lockTimeoutInsertStatementKey, + ), + ) return flags } @@ -156,10 +154,14 @@ func schemaFlagsVar(cmd *cobra.Command, p *schemaFlags) { func timeoutModifierFlagVar(cmd *cobra.Command, p *[]string, timeoutType string, shorthand string) { flagName := fmt.Sprintf("%s-timeout-modifier", timeoutType) - desc := fmt.Sprintf("regex=timeout key-value pairs, where if a statement matches the regex, the statement "+ - "will be modified to have the %s timeout. If multiple regexes match, the latest regex will take priority. "+ - "Example: -t 'CREATE TABLE=5m' -t 'CONCURRENTLY=10s'", timeoutType) - cmd.Flags().StringArrayVarP(p, flagName, shorthand, nil, desc) + description := fmt.Sprintf("list of '%s=\"\" %s=', where if a statement matches "+ + "the regex, the statement will have the target %s timeout. If multiple regexes match, the latest regex will "+ + "take priority. Example: -t '%s=\"CREATE TABLE\" %s=5m'", + patternTimeoutModifierKey, timeoutTimeoutModifierKey, + timeoutType, + patternTimeoutModifierKey, timeoutTimeoutModifierKey, + ) + cmd.Flags().StringArrayVarP(p, flagName, shorthand, nil, description) } func parsePlanConfig(p planFlags) (planConfig, error) { @@ -168,7 +170,7 @@ func parsePlanConfig(p planFlags) (planConfig, error) { return planConfig{}, err } - var statementTimeoutModifiers []timeoutModifiers + var statementTimeoutModifiers []timeoutModifier for _, s := range p.statementTimeoutModifiers { stm, err := parseTimeoutModifier(s) if err != nil { @@ -177,7 +179,7 @@ func parsePlanConfig(p planFlags) (planConfig, error) { statementTimeoutModifiers = append(statementTimeoutModifiers, stm) } - var lockTimeoutModifiers []timeoutModifiers + var lockTimeoutModifiers []timeoutModifier for _, s := range p.lockTimeoutModifiers { ltm, err := parseTimeoutModifier(s) if err != nil { @@ -239,55 +241,94 @@ func parseSchemaConfig(p schemaFlags) []diff.PlanOpt { } } -func parseTimeoutModifier(val string) (timeoutModifiers, error) { - submatches := statementTimeoutModifierRegex.FindStringSubmatch(val) - if len(submatches) <= regexSTMRegexIndex || len(submatches) <= durationSTMRegexIndex { - return timeoutModifiers{}, fmt.Errorf("could not parse regex and duration from arg. expected to be in the format of " + - "'Some.*Regex='. Example durations include: 2s, 5m, 10.5h") +// parseTimeoutModifier attempts to parse an option representing a statement timeout modifier in the +// form of regex=duration where duration could be a decimal number and ends with a unit +func parseTimeoutModifier(val string) (timeoutModifier, error) { + fm, err := LogFmtToMap(val) + if err != nil { + return timeoutModifier{}, fmt.Errorf("could not parse %q into logfmt: %w", val, err) + } + + regexStr, err := mustGetAndDeleteKey(fm, patternTimeoutModifierKey) + if err != nil { + return timeoutModifier{}, err + } + + timeoutStr, err := mustGetAndDeleteKey(fm, timeoutTimeoutModifierKey) + if err != nil { + return timeoutModifier{}, err + } + + if len(fm) > 0 { + return timeoutModifier{}, fmt.Errorf("unknown keys %s", keys(fm)) } - regexStr := submatches[regexSTMRegexIndex] - durationStr := submatches[durationSTMRegexIndex] - regex, err := regexp.Compile(regexStr) + duration, err := time.ParseDuration(timeoutStr) if err != nil { - return timeoutModifiers{}, fmt.Errorf("regex could not be compiled from %q: %w", regexStr, err) + return timeoutModifier{}, fmt.Errorf("duration could not be parsed from %q: %w", timeoutStr, err) } - duration, err := time.ParseDuration(durationStr) + re, err := regexp.Compile(regexStr) if err != nil { - return timeoutModifiers{}, fmt.Errorf("duration could not be parsed from %q: %w", durationStr, err) + return timeoutModifier{}, fmt.Errorf("pattern regex could not be compiled from %q: %w", regexStr, err) } - return timeoutModifiers{ - regex: regex, + return timeoutModifier{ + regex: re, timeout: duration, }, nil } func parseInsertStatementStr(val string) (insertStatement, error) { - submatches := insertStatementRegex.FindStringSubmatch(val) - if len(submatches) <= indexInsertStatementRegexIndex || - len(submatches) <= durationInsertStatementRegexIndex || - len(submatches) <= ddlInsertStatementRegexIndex { - return insertStatement{}, fmt.Errorf("could not parse index, duration, and statement from arg. expected to be in the " + - "format of ' :'. Example durations include: 2s, 5m, 10.5h") - } - indexStr := submatches[indexInsertStatementRegexIndex] + fm, err := LogFmtToMap(val) + if err != nil { + return insertStatement{}, fmt.Errorf("could not parse into logfmt: %w", err) + } + + indexStr, err := mustGetAndDeleteKey(fm, indexInsertStatementKey) + if err != nil { + return insertStatement{}, err + } + + statementStr, err := mustGetAndDeleteKey(fm, statementInsertStatementKey) + if err != nil { + return insertStatement{}, err + } + + statementTimeoutStr, err := mustGetAndDeleteKey(fm, statementTimeoutInsertStatementKey) + if err != nil { + return insertStatement{}, err + } + + lockTimeoutStr, err := mustGetAndDeleteKey(fm, lockTimeoutInsertStatementKey) + if err != nil { + return insertStatement{}, err + } + + if len(fm) > 0 { + return insertStatement{}, fmt.Errorf("unknown keys %s", keys(fm)) + } + index, err := strconv.Atoi(indexStr) if err != nil { - return insertStatement{}, fmt.Errorf("could not parse index (an int) from \"%q\"", indexStr) + return insertStatement{}, fmt.Errorf("index could not be parsed from %q: %w", indexStr, err) } - durationStr := submatches[durationInsertStatementRegexIndex] - duration, err := time.ParseDuration(durationStr) + statementTimeout, err := time.ParseDuration(statementTimeoutStr) if err != nil { - return insertStatement{}, fmt.Errorf("duration could not be parsed from \"%q\": %w", durationStr, err) + return insertStatement{}, fmt.Errorf("statement timeout duration could not be parsed from %q: %w", statementTimeoutStr, err) + } + + lockTimeout, err := time.ParseDuration(lockTimeoutStr) + if err != nil { + return insertStatement{}, fmt.Errorf("lock timeout duration could not be parsed from %q: %w", lockTimeoutStr, err) } return insertStatement{ - index: index, - ddl: submatches[ddlInsertStatementRegexIndex], - timeout: duration, + index: index, + ddl: statementStr, + timeout: statementTimeout, + lockTimeout: lockTimeout, }, nil } @@ -357,16 +398,16 @@ func applyPlanModifiers( for _, is := range config.insertStatements { var err error plan, err = plan.InsertStatement(is.index, diff.Statement{ - DDL: is.ddl, - Timeout: is.timeout, + DDL: is.ddl, + Timeout: is.timeout, + LockTimeout: is.lockTimeout, Hazards: []diff.MigrationHazard{{ Type: diff.MigrationHazardTypeIsUserGenerated, Message: "This statement is user-generated", }}, }) if err != nil { - return diff.Plan{}, fmt.Errorf("inserting statement %q with timeout %s at index %d: %w", - is.ddl, is.timeout, is.index, err) + return diff.Plan{}, fmt.Errorf("inserting %+v: %w", is, err) } } return plan, nil diff --git a/cmd/pg-schema-diff/plan_cmd_test.go b/cmd/pg-schema-diff/plan_cmd_test.go index 0a7d178..b323411 100644 --- a/cmd/pg-schema-diff/plan_cmd_test.go +++ b/cmd/pg-schema-diff/plan_cmd_test.go @@ -1,6 +1,7 @@ package main import ( + "regexp" "testing" "time" @@ -8,54 +9,46 @@ import ( "github.com/stretchr/testify/require" ) -func TestParseStatementTimeoutModifierStr(t *testing.T) { +func TestParseTimeoutModifierStr(t *testing.T) { for _, tc := range []struct { opt string `explicit:"always"` - expectedRegexStr string - expectedTimeout time.Duration + expected timeoutModifier expectedErrContains string }{ { - opt: "normal duration=5m", - expectedRegexStr: "normal duration", - expectedTimeout: 5 * time.Minute, - }, - { - opt: "some regex with a duration ending in a period=5.h", - expectedRegexStr: "some regex with a duration ending in a period", - expectedTimeout: 5 * time.Hour, - }, - { - opt: " starts with spaces than has a *=5.5m", - expectedRegexStr: " starts with spaces than has a *", - expectedTimeout: time.Minute*5 + 30*time.Second, + opt: `pattern="normal \"pattern\"" timeout=5m`, + expected: timeoutModifier{ + regex: regexp.MustCompile(`normal "pattern"`), + timeout: 5 * time.Minute, + }, }, { - opt: "has a valid opt in the regex something=5.5m in the regex =15s", - expectedRegexStr: "has a valid opt in the regex something=5.5m in the regex ", - expectedTimeout: 15 * time.Second, + opt: `pattern=unquoted-no-space-pattern timeout=5m`, + expected: timeoutModifier{ + regex: regexp.MustCompile("unquoted-no-space-pattern"), + timeout: 5 * time.Minute, + }, }, { - opt: "has multiple valid opts opt=15m5s in the regex something=5.5m in the regex and has compound duration=15m1ms2us10ns", - expectedRegexStr: "has multiple valid opts opt=15m5s in the regex something=5.5m in the regex and has compound duration", - expectedTimeout: 15*time.Minute + 1*time.Millisecond + 2*time.Microsecond + 10*time.Nanosecond, + opt: "timeout=15m", + expectedErrContains: "could not find key", }, { - opt: "=5m", - expectedErrContains: "could not parse regex and duration from arg", + opt: `pattern="some pattern"`, + expectedErrContains: "could not find key", }, { - opt: "15m", - expectedErrContains: "could not parse regex and duration from arg", + opt: `pattern="normal" timeout=5m some-unknown-key=5m`, + expectedErrContains: "unknown keys", }, { - opt: "someregex;15m", - expectedErrContains: "could not parse regex and duration from arg", + opt: `pattern="some-pattern" timeout=invalid-duration`, + expectedErrContains: "duration could not be parsed", }, { - opt: "someregex=invalid duration5s", - expectedErrContains: "duration could not be parsed", + opt: `pattern="some-invalid-pattern-[" timeout=5m`, + expectedErrContains: "pattern regex could not be compiled", }, } { t.Run(tc.opt, func(t *testing.T) { @@ -65,8 +58,7 @@ func TestParseStatementTimeoutModifierStr(t *testing.T) { return } require.NoError(t, err) - assert.Equal(t, tc.expectedRegexStr, modifier.regex.String()) - assert.Equal(t, tc.expectedTimeout, modifier.timeout) + assert.Equal(t, tc.expected, modifier) }) } } @@ -78,32 +70,41 @@ func TestParseInsertStatementStr(t *testing.T) { expectedErrContains string }{ { - opt: "1 0h5.1m:SELECT * FROM :TABLE:0_5m:something", + opt: `index=1 statement="SELECT * FROM \"foobar\"" timeout=5m6s lock_timeout=1m11s`, expectedInsertStmt: insertStatement{ - index: 1, - ddl: "SELECT * FROM :TABLE:0_5m:something", - timeout: 5*time.Minute + 6*time.Second, + index: 1, + ddl: `SELECT * FROM "foobar"`, + timeout: 5*time.Minute + 6*time.Second, + lockTimeout: 1*time.Minute + 11*time.Second, }, }, { - opt: "0 100ms:SELECT 1; SELECT * FROM something;", - expectedInsertStmt: insertStatement{ - index: 0, - ddl: "SELECT 1; SELECT * FROM something", - timeout: 100 * time.Millisecond, - }, + opt: "statement=no-index timeout=5m6s lock_timeout=1m11s", + expectedErrContains: "could not find key", }, { - opt: " 5s:No index", - expectedErrContains: "could not parse", + opt: "index=0 timeout=5m6s lock_timeout=1m11s", + expectedErrContains: "could not find key", }, { - opt: "0 5g:Invalid duration", - expectedErrContains: "duration could not be parsed", + opt: "index=0 statement=no-timeout lock_timeout=1m11s", + expectedErrContains: "could not find key", + }, + { + opt: "index=0 statement=no-lock-timeout-timeout timeout=5m6s", + expectedErrContains: "could not find key", + }, + { + opt: "index=not-an-int statement=some-statement timeout=5m6s lock_timeout=1m11s", + expectedErrContains: "index could not be parsed", + }, + { + opt: "index=0 statement=some-statement timeout=invalid-duration lock_timeout=1m11s", + expectedErrContains: "statement timeout duration could not be parsed", }, { - opt: "0 5s:", - expectedErrContains: "could not parse", + opt: "index=0 statement=some-statement timeout=5m6s lock_timeout=invalid-duration", + expectedErrContains: "lock timeout duration could not be parsed", }, } { t.Run(tc.opt, func(t *testing.T) { diff --git a/go.mod b/go.mod index cbcd2ca..97566f2 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/stripe/pg-schema-diff go 1.18 require ( + github.com/go-logfmt/logfmt v0.6.0 github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.14.0 diff --git a/go.sum b/go.sum index 555016d..6321891 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=