diff --git a/named_args.go b/named_args.go index 8367fc63a..c88991ee4 100644 --- a/named_args.go +++ b/named_args.go @@ -2,6 +2,7 @@ package pgx import ( "context" + "fmt" "strconv" "strings" "unicode/utf8" @@ -21,6 +22,34 @@ type NamedArgs map[string]any // RewriteQuery implements the QueryRewriter interface. func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { l := &sqlLexer{ src: sql, stateFn: rawState, @@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar newArgs = make([]any, len(l.nameToOrdinal)) for name, ordinal := range l.nameToOrdinal { - newArgs[ordinal-1] = na[string(name)] + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } } - return sb.String(), newArgs, nil -} - -type namedArg string - -type sqlLexer struct { - src string - start int - pos int - nested int // multiline comment nesting level. - stateFn stateFn - parts []any + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } + } - nameToOrdinal map[namedArg]int + return sb.String(), newArgs, nil } -type stateFn func(*sqlLexer) stateFn - func rawState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) diff --git a/named_args_test.go b/named_args_test.go index 49ac817da..8cab2f4d2 100644 --- a/named_args_test.go +++ b/named_args_test.go @@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) { where id = $1;`, expectedArgs: []any{int32(42)}, }, + { + sql: "extra provided argument", + namedArgs: pgx.NamedArgs{"extra": int32(1)}, + expectedSQL: "extra provided argument", + expectedArgs: []any{}, + }, + { + sql: "@missing argument", + namedArgs: pgx.NamedArgs{}, + expectedSQL: "$1 argument", + expectedArgs: []any{nil}, + }, // test comments and quotes } { @@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) { assert.Equalf(t, tt.expectedArgs, args, "%d", i) } } + +func TestStrictNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + namedArgs pgx.StrictNamedArgs + expectedSQL string + expectedArgs []any + isExpectedError bool + }{ + { + sql: "no arguments", + namedArgs: pgx.StrictNamedArgs{}, + expectedSQL: "no arguments", + expectedArgs: []any{}, + isExpectedError: false, + }, + { + sql: "@all @matches", + namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)}, + expectedSQL: "$1 $2", + expectedArgs: []any{int32(1), int32(2)}, + isExpectedError: false, + }, + { + sql: "extra provided argument", + namedArgs: pgx.StrictNamedArgs{"extra": int32(1)}, + isExpectedError: true, + }, + { + sql: "@missing argument", + namedArgs: pgx.StrictNamedArgs{}, + isExpectedError: true, + }, + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil) + if tt.isExpectedError { + assert.Errorf(t, err, "%d", i) + } else { + require.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } + } +}