Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StrictNamedArgs #1941

Merged
merged 1 commit into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions named_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgx

import (
"context"
"fmt"
"strconv"
"strings"
"unicode/utf8"
Expand All @@ -21,6 +22,34 @@ type NamedArgs map[string]any

// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}

// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
// named arguments that the sql query uses, and no extra arguments.
type StrictNamedArgs map[string]any

// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}

type namedArg string

type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any

nameToOrdinal map[namedArg]int
}

type stateFn func(*sqlLexer) stateFn

func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
Expand All @@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar

newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
newArgs[ordinal-1] = na[string(name)]
var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
}

return sb.String(), newArgs, nil
}

type namedArg string

type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
}

nameToOrdinal map[namedArg]int
return sb.String(), newArgs, nil
}

type stateFn func(*sqlLexer) stateFn

func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
Expand Down
58 changes: 58 additions & 0 deletions named_args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: "@missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},

// test comments and quotes
} {
Expand All @@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}

func TestStrictNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()

for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: "@all @matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: "@missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}
Loading