Skip to content

Commit

Permalink
Prepare chooses statement name based on sql if name == sql
Browse files Browse the repository at this point in the history
This makes it easier to explicitly manage prepared statements.

refs #1716
  • Loading branch information
jackc committed Sep 23, 2023
1 parent 4e7aa59 commit bbe2653
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
43 changes: 32 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package pgx

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
Expand Down Expand Up @@ -284,12 +286,15 @@ func (c *Conn) Close(ctx context.Context) error {
return err
}

// Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
// Exec to execute the statement. It can also be used with Batch.Queue.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
// name == sql.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This
// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared.
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
if c.prepareTracer != nil {
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
Expand All @@ -311,22 +316,38 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem
}()
}

sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
var psName, psKey string
if name == sql {
digest := sha256.Sum256([]byte(sql))
psName = "stmt_" + hex.EncodeToString(digest[0:24])
psKey = sql
} else {
psName = name
psKey = name
}

sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
if err != nil {
return nil, err
}

if name != "" {
c.preparedStatements[name] = sd
if psKey != "" {
c.preparedStatements[psKey] = sd
}

return sd, nil
}

// Deallocate released a prepared statement
// Deallocate releases a prepared statement.
func (c *Conn) Deallocate(ctx context.Context, name string) error {
delete(c.preparedStatements, name)
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
var psName string
if sd, ok := c.preparedStatements[name]; ok {
delete(c.preparedStatements, name)
psName = sd.Name
} else {
psName = name
}
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(psName)).ReadAll()
return err
}

Expand Down
22 changes: 22 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,28 @@ func TestPrepareStatementCacheModes(t *testing.T) {
})
}

func TestPrepareWithDigestedName(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := "select $1::text"
sd, err := conn.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)

var s string
err = conn.QueryRow(ctx, sql, "hello").Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)

err = conn.Deallocate(ctx, sql)
require.NoError(t, err)
})
}

func TestListenNotify(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit bbe2653

Please sign in to comment.