diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 4aca2fb98..fd1e808b4 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -70,9 +70,9 @@ func (q *Query) Sanitize(args ...any) (string, error) { case bool: p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - p = quoteBytes(buf.AvailableBuffer(), arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - p = quoteString(buf.AvailableBuffer(), arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: p = arg.Truncate(time.Microsecond). AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") @@ -135,11 +135,7 @@ func (q *Query) init(sql string) { q.Parts = l.parts } -func QuoteString(str string) string { - return string(quoteString(nil, str)) -} - -func quoteString(dst []byte, str string) []byte { +func QuoteString(dst []byte, str string) []byte { const quote = "'" n := strings.Count(str, quote) @@ -166,11 +162,7 @@ func quoteString(dst []byte, str string) []byte { return dst } -func QuoteBytes(buf []byte) string { - return string(quoteBytes(nil, buf)) -} - -func quoteBytes(dst, buf []byte) []byte { +func QuoteBytes(dst, buf []byte) []byte { dst = append(dst, `'\x`...) n := hex.EncodedLen(len(buf)) diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go index 7d594def0..746558276 100644 --- a/internal/sanitize/sanitize_fuzz_test.go +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -14,10 +14,10 @@ func FuzzQuoteString(f *testing.F) { f.Add("select 'quoted $42', $1") f.Fuzz(func(t *testing.T, input string) { - got := sanitize.QuoteString(input) + got := sanitize.QuoteString(nil, input) want := oldQuoteString(input) - if want != got { + if want != string(got) { t.Errorf("got %q", got) t.Fatalf("want %q", want) } @@ -32,10 +32,10 @@ func FuzzQuoteBytes(f *testing.F) { f.Add([]byte("select 'quoted $42', $1")) f.Fuzz(func(t *testing.T, input []byte) { - got := sanitize.QuoteBytes(input) + got := sanitize.QuoteBytes(nil, input) want := oldQuoteBytes(input) - if want != got { + if want != string(got) { t.Errorf("got %q", got) t.Fatalf("want %q", want) } diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index aafcd682d..9da701ea9 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -235,7 +235,7 @@ func TestQuoteString(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - got := sanitize.QuoteString(input) + got := string(sanitize.QuoteString(nil, input)) want := oldQuoteString(input) if got != want { @@ -259,7 +259,7 @@ func TestQuoteBytes(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - got := sanitize.QuoteBytes(input) + got := string(sanitize.QuoteBytes(nil, input)) want := oldQuoteBytes(input) if got != want {