Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce SQL sanitizer allocations #2136

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())

require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
}
60 changes: 60 additions & 0 deletions internal/sanitize/benchmmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash

current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi

restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT

# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi

# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi

commits=("$@")
benchmarks_dir=benchmarks

if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi

# Benchmark results
bench_files=()

# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}

# Sanitized commmit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')

# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"

if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi

bench_files+=("$bench_file")
done

# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you prefix with a small comment: # go install golang.org/x/perf/cmd/benchstat@latest

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

183 changes: 156 additions & 27 deletions internal/sanitize/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
Expand All @@ -24,53 +26,75 @@ type Query struct {
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3

const maxBufSize = 16384 // 16 Ki

var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}

var null = []byte("null")

func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
buf := bufPool.get()
defer bufPool.put(buf)

for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
buf.WriteString(part)
case int:
argIdx := part - 1

var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}

if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}

// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')

arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
p = null
case int64:
str = strconv.FormatInt(arg, 10)
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte:
str = QuoteBytes(arg)
p = QuoteBytes(buf.AvailableBuffer(), arg)
case string:
str = QuoteString(arg)
p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true

buf.Write(p)

// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}

for i, used := range argUse {
Expand All @@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
}

func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
query := &Query{}
query.init(sql)

return query, nil
}

var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}

func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% usecases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}

l := sqlLexerPool.get()
defer sqlLexerPool.put(l)

l.src = sql
l.stateFn = rawState
l.parts = parts

for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}

query := &Query{Parts: l.parts}

return query, nil
q.Parts = l.parts
}

func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
func QuoteString(dst []byte, str string) []byte {
const quote = '\''

// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)

// Add opening quote
dst = append(dst, quote)

// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}

// Add closing quote
dst = append(dst, quote)

return dst
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is purely a style nit, but I don't like reslicing for these types of functions because it's not idiomatic and hard to follow. I took the above QuoteString() and replaced it with something that uses an iterator:

func QuoteString(dst []byte, str string) []byte {
        const quote = '\''

        // Preallocate space for the worst case scenario
        dst = slices.Grow(dst, len(str)*2+2)

        // Add opening quote
        dst = append(dst, quote)

        // Iterate through the string without allocating
        for i := 0; i < len(str); i++ {
                if str[i] == quote {
                        dst = append(dst, quote, quote)
                } else {
                        dst = append(dst, str[i])
                }
        }

        // Add closing quote
        dst = append(dst, quote)

        return dst
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}

// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1

// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}

// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]

// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'

// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)

// Add suffix
dst[len(dst)-1] = '\''

return dst
}
Copy link
Contributor

@sean- sean- Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to measure an improvement by optimizing this function:

func QuoteBytes(dst, buf []byte) []byte {
        if len(buf) == 0 {
                return append(dst, `'\x'`...)
        }

        // Calculate required length
        requiredLen := 3 + hex.EncodedLen(len(buf)) + 1

        // Ensure dst has enough capacity
        if cap(dst)-len(dst) < requiredLen {
                newDst := make([]byte, len(dst), len(dst)+requiredLen)
                copy(newDst, dst)
                dst = newDst
        }

        // Record original length and extend slice
        origLen := len(dst)
        dst = dst[:origLen+requiredLen]

        // Add prefix
        dst[origLen] = '\''
        dst[origLen+1] = '\\'
        dst[origLen+2] = 'x'

        // Encode bytes directly into dst
        hex.Encode(dst[origLen+3:len(dst)-1], buf)

        // Add suffix
        dst[len(dst)-1] = '\''

        return dst
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


type sqlLexer struct {
Expand Down Expand Up @@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
}
}

var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}

// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)

return query.Sanitize(args...)
}

type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}

func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}

return v
}

func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}
62 changes: 62 additions & 0 deletions internal/sanitize/sanitize_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// sanitize_benchmark_test.go
package sanitize_test

import (
"testing"
"time"

"github.com/jackc/pgx/v5/internal/sanitize"
)

var benchmarkSanitizeResult string

const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`

var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}

func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}

var benchmarkNewSQLResult string

func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}
Loading