-
Notifications
You must be signed in to change notification settings - Fork 857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reduce SQL sanitizer allocations #2136
base: master
Are you sure you want to change the base?
Changes from all commits
afa974f
aabed18
efc2c9f
546ad2f
ee718a1
1752f7b
58d4c0c
ea1e13a
4293b25
c4c1076
39ffc8b
59d6aa8
90a77b1
47cbd8e
057937d
120c89f
da0315d
e452f80
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#!/usr/bin/env bash | ||
|
||
current_branch=$(git rev-parse --abbrev-ref HEAD) | ||
if [ "$current_branch" == "HEAD" ]; then | ||
current_branch=$(git rev-parse HEAD) | ||
fi | ||
|
||
restore_branch() { | ||
echo "Restoring original branch/commit: $current_branch" | ||
git checkout "$current_branch" | ||
} | ||
trap restore_branch EXIT | ||
|
||
# Check if there are uncommitted changes | ||
if ! git diff --quiet || ! git diff --cached --quiet; then | ||
echo "There are uncommitted changes. Please commit or stash them before running this script." | ||
exit 1 | ||
fi | ||
|
||
# Ensure that at least one commit argument is passed | ||
if [ "$#" -lt 1 ]; then | ||
echo "Usage: $0 <commit1> <commit2> ... <commitN>" | ||
exit 1 | ||
fi | ||
|
||
commits=("$@") | ||
benchmarks_dir=benchmarks | ||
|
||
if ! mkdir -p "${benchmarks_dir}"; then | ||
echo "Unable to create dir for benchmarks data" | ||
exit 1 | ||
fi | ||
|
||
# Benchmark results | ||
bench_files=() | ||
|
||
# Run benchmark for each listed commit | ||
for i in "${!commits[@]}"; do | ||
commit="${commits[i]}" | ||
git checkout "$commit" || { | ||
echo "Failed to checkout $commit" | ||
exit 1 | ||
} | ||
|
||
# Sanitized commmit message | ||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') | ||
|
||
# Benchmark data will go there | ||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" | ||
|
||
if ! go test -bench=. -count=10 >"$bench_file"; then | ||
echo "Benchmarking failed for commit $commit" | ||
exit 1 | ||
fi | ||
|
||
bench_files+=("$bench_file") | ||
done | ||
|
||
# go install golang.org/x/perf/cmd/benchstat[@latest] | ||
benchstat "${bench_files[@]}" | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,10 @@ import ( | |
"bytes" | ||
"encoding/hex" | ||
"fmt" | ||
"slices" | ||
"strconv" | ||
"strings" | ||
"sync" | ||
"time" | ||
"unicode/utf8" | ||
) | ||
|
@@ -24,53 +26,75 @@ type Query struct { | |
// https://github.com/jackc/pgx/issues/1380 | ||
const replacementcharacterwidth = 3 | ||
|
||
const maxBufSize = 16384 // 16 Ki | ||
|
||
var bufPool = &pool[*bytes.Buffer]{ | ||
new: func() *bytes.Buffer { | ||
return &bytes.Buffer{} | ||
}, | ||
reset: func(b *bytes.Buffer) bool { | ||
n := b.Len() | ||
b.Reset() | ||
return n < maxBufSize | ||
}, | ||
} | ||
|
||
var null = []byte("null") | ||
|
||
func (q *Query) Sanitize(args ...any) (string, error) { | ||
argUse := make([]bool, len(args)) | ||
buf := &bytes.Buffer{} | ||
buf := bufPool.get() | ||
defer bufPool.put(buf) | ||
|
||
for _, part := range q.Parts { | ||
var str string | ||
switch part := part.(type) { | ||
case string: | ||
str = part | ||
buf.WriteString(part) | ||
case int: | ||
argIdx := part - 1 | ||
|
||
var p []byte | ||
if argIdx < 0 { | ||
return "", fmt.Errorf("first sql argument must be > 0") | ||
} | ||
|
||
if argIdx >= len(args) { | ||
return "", fmt.Errorf("insufficient arguments") | ||
} | ||
|
||
// Prevent SQL injection via Line Comment Creation | ||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p | ||
buf.WriteByte(' ') | ||
|
||
arg := args[argIdx] | ||
switch arg := arg.(type) { | ||
case nil: | ||
str = "null" | ||
p = null | ||
case int64: | ||
str = strconv.FormatInt(arg, 10) | ||
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) | ||
case float64: | ||
str = strconv.FormatFloat(arg, 'f', -1, 64) | ||
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) | ||
case bool: | ||
str = strconv.FormatBool(arg) | ||
p = strconv.AppendBool(buf.AvailableBuffer(), arg) | ||
case []byte: | ||
str = QuoteBytes(arg) | ||
p = QuoteBytes(buf.AvailableBuffer(), arg) | ||
case string: | ||
str = QuoteString(arg) | ||
p = QuoteString(buf.AvailableBuffer(), arg) | ||
case time.Time: | ||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") | ||
p = arg.Truncate(time.Microsecond). | ||
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") | ||
default: | ||
return "", fmt.Errorf("invalid arg type: %T", arg) | ||
} | ||
argUse[argIdx] = true | ||
|
||
buf.Write(p) | ||
|
||
// Prevent SQL injection via Line Comment Creation | ||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p | ||
str = " " + str + " " | ||
buf.WriteByte(' ') | ||
default: | ||
return "", fmt.Errorf("invalid Part type: %T", part) | ||
} | ||
buf.WriteString(str) | ||
} | ||
|
||
for i, used := range argUse { | ||
|
@@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) { | |
} | ||
|
||
func NewQuery(sql string) (*Query, error) { | ||
l := &sqlLexer{ | ||
src: sql, | ||
stateFn: rawState, | ||
query := &Query{} | ||
query.init(sql) | ||
|
||
return query, nil | ||
} | ||
|
||
var sqlLexerPool = &pool[*sqlLexer]{ | ||
new: func() *sqlLexer { | ||
return &sqlLexer{} | ||
}, | ||
reset: func(sl *sqlLexer) bool { | ||
*sl = sqlLexer{} | ||
return true | ||
}, | ||
} | ||
|
||
func (q *Query) init(sql string) { | ||
parts := q.Parts[:0] | ||
if parts == nil { | ||
// dirty, but fast heuristic to preallocate for ~90% usecases | ||
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 | ||
parts = make([]Part, 0, n) | ||
} | ||
|
||
l := sqlLexerPool.get() | ||
defer sqlLexerPool.put(l) | ||
|
||
l.src = sql | ||
l.stateFn = rawState | ||
l.parts = parts | ||
|
||
for l.stateFn != nil { | ||
l.stateFn = l.stateFn(l) | ||
} | ||
|
||
query := &Query{Parts: l.parts} | ||
|
||
return query, nil | ||
q.Parts = l.parts | ||
} | ||
|
||
func QuoteString(str string) string { | ||
return "'" + strings.ReplaceAll(str, "'", "''") + "'" | ||
func QuoteString(dst []byte, str string) []byte { | ||
const quote = '\'' | ||
|
||
// Preallocate space for the worst case scenario | ||
dst = slices.Grow(dst, len(str)*2+2) | ||
|
||
// Add opening quote | ||
dst = append(dst, quote) | ||
|
||
// Iterate through the string without allocating | ||
for i := 0; i < len(str); i++ { | ||
if str[i] == quote { | ||
dst = append(dst, quote, quote) | ||
} else { | ||
dst = append(dst, str[i]) | ||
} | ||
} | ||
|
||
// Add closing quote | ||
dst = append(dst, quote) | ||
|
||
return dst | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is purely a style nit, but I don't like reslicing for these types of functions because it's not idiomatic and hard to follow. I took the above func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
func QuoteBytes(buf []byte) string { | ||
return `'\x` + hex.EncodeToString(buf) + "'" | ||
func QuoteBytes(dst, buf []byte) []byte { | ||
if len(buf) == 0 { | ||
return append(dst, `'\x'`...) | ||
} | ||
|
||
// Calculate required length | ||
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 | ||
|
||
// Ensure dst has enough capacity | ||
if cap(dst)-len(dst) < requiredLen { | ||
newDst := make([]byte, len(dst), len(dst)+requiredLen) | ||
copy(newDst, dst) | ||
dst = newDst | ||
} | ||
|
||
// Record original length and extend slice | ||
origLen := len(dst) | ||
dst = dst[:origLen+requiredLen] | ||
|
||
// Add prefix | ||
dst[origLen] = '\'' | ||
dst[origLen+1] = '\\' | ||
dst[origLen+2] = 'x' | ||
|
||
// Encode bytes directly into dst | ||
hex.Encode(dst[origLen+3:len(dst)-1], buf) | ||
|
||
// Add suffix | ||
dst[len(dst)-1] = '\'' | ||
|
||
return dst | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was able to measure an improvement by optimizing this function: func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
type sqlLexer struct { | ||
|
@@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn { | |
} | ||
} | ||
|
||
var queryPool = &pool[*Query]{ | ||
new: func() *Query { | ||
return &Query{} | ||
}, | ||
reset: func(q *Query) bool { | ||
n := len(q.Parts) | ||
q.Parts = q.Parts[:0] | ||
return n < 64 // drop too large queries | ||
}, | ||
} | ||
|
||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args | ||
// as necessary. This function is only safe when standard_conforming_strings is | ||
// on. | ||
func SanitizeSQL(sql string, args ...any) (string, error) { | ||
query, err := NewQuery(sql) | ||
if err != nil { | ||
return "", err | ||
} | ||
query := queryPool.get() | ||
query.init(sql) | ||
defer queryPool.put(query) | ||
|
||
return query.Sanitize(args...) | ||
} | ||
|
||
type pool[E any] struct { | ||
p sync.Pool | ||
new func() E | ||
reset func(E) bool | ||
} | ||
|
||
func (pool *pool[E]) get() E { | ||
v, ok := pool.p.Get().(E) | ||
if !ok { | ||
v = pool.new() | ||
} | ||
|
||
return v | ||
} | ||
|
||
func (p *pool[E]) put(v E) { | ||
if p.reset(v) { | ||
p.p.Put(v) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// sanitize_benchmark_test.go | ||
package sanitize_test | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/jackc/pgx/v5/internal/sanitize" | ||
) | ||
|
||
var benchmarkSanitizeResult string | ||
|
||
const benchmarkQuery = "" + | ||
`SELECT * | ||
FROM "water_containers" | ||
WHERE NOT "id" = $1 -- int64 | ||
AND "tags" NOT IN $2 -- nil | ||
AND "volume" > $3 -- float64 | ||
AND "transportable" = $4 -- bool | ||
AND position($5 IN "sign") -- bytes | ||
AND "label" LIKE $6 -- string | ||
AND "created_at" > $7; -- time.Time` | ||
|
||
var benchmarkArgs = []any{ | ||
int64(12345), | ||
nil, | ||
float64(500), | ||
true, | ||
[]byte("8BADF00D"), | ||
"kombucha's han'dy awokowa", | ||
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC), | ||
} | ||
|
||
func BenchmarkSanitize(b *testing.B) { | ||
query, err := sanitize.NewQuery(benchmarkQuery) | ||
if err != nil { | ||
b.Fatalf("failed to create query: %v", err) | ||
} | ||
|
||
b.ResetTimer() | ||
b.ReportAllocs() | ||
|
||
for i := 0; i < b.N; i++ { | ||
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...) | ||
if err != nil { | ||
b.Fatalf("failed to sanitize query: %v", err) | ||
} | ||
} | ||
} | ||
|
||
var benchmarkNewSQLResult string | ||
|
||
func BenchmarkSanitizeSQL(b *testing.B) { | ||
b.ReportAllocs() | ||
var err error | ||
for i := 0; i < b.N; i++ { | ||
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...) | ||
if err != nil { | ||
b.Fatalf("failed to sanitize SQL: %v", err) | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you prefix with a small comment:
# go install golang.org/x/perf/cmd/benchstat@latest
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done