Skip to content

Commit

Permalink
StrictNamedArgs
Browse files Browse the repository at this point in the history
  • Loading branch information
Fazt01 authored and jackc committed Mar 16, 2024
1 parent 1b6227a commit b6e5548
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 16 deletions.
58 changes: 42 additions & 16 deletions named_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgx

import (
"context"
"fmt"
"strconv"
"strings"
"unicode/utf8"
Expand All @@ -21,6 +22,34 @@ type NamedArgs map[string]any

// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}

// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
// named arguments that the sql query uses, and no extra arguments.
type StrictNamedArgs map[string]any

// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}

type namedArg string

type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any

nameToOrdinal map[namedArg]int
}

type stateFn func(*sqlLexer) stateFn

func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
Expand All @@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar

newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
newArgs[ordinal-1] = na[string(name)]
var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
}

return sb.String(), newArgs, nil
}

type namedArg string

type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
}

nameToOrdinal map[namedArg]int
return sb.String(), newArgs, nil
}

type stateFn func(*sqlLexer) stateFn

func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
Expand Down
58 changes: 58 additions & 0 deletions named_args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: "@missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},

// test comments and quotes
} {
Expand All @@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}

func TestStrictNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()

for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: "@all @matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: "@missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}

0 comments on commit b6e5548

Please sign in to comment.