Skip to content

Commit

Permalink
Fix: prepared statement already exists
Browse files Browse the repository at this point in the history
When a conn is going to execute a query, the first thing it does is to
deallocate any invalidated prepared statements from the statement cache.
However, the statements were removed from the cache regardless of
whether the deallocation succeeded. This would cause subsequent calls of
the same SQL to fail with "prepared statement already exists" error.

This problem is easy to trigger by running a query with a context that
is already canceled.

This commit changes the deallocate invalidated cached statements logic
so that the statements are only removed from the cache if the
deallocation was successful on the server.

#1847
  • Loading branch information
jackc committed Feb 3, 2024
1 parent fd44114 commit 832b4f9
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 13 deletions.
10 changes: 7 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1359,12 +1359,12 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
}

if c.descriptionCache != nil {
c.descriptionCache.HandleInvalidated()
c.descriptionCache.RemoveInvalidated()
}

var invalidatedStatements []*pgconn.StatementDescription
if c.statementCache != nil {
invalidatedStatements = c.statementCache.HandleInvalidated()
invalidatedStatements = c.statementCache.GetInvalidated()
}

if len(invalidatedStatements) == 0 {
Expand All @@ -1376,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error

for _, sd := range invalidatedStatements {
pipeline.SendDeallocate(sd.Name)
delete(c.preparedStatements, sd.Name)
}

err := pipeline.Sync()
Expand All @@ -1389,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
}

c.statementCache.RemoveInvalidated()
for _, sd := range invalidatedStatements {
delete(c.preparedStatements, sd.Name)
}

return nil
}
29 changes: 29 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1338,3 +1338,32 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
})
}

// https://github.com/jackc/pgx/issues/1847
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var n int32
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)

// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
// encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn
// we could call conn.statementCache.InvalidateAll() instead.
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
require.Error(t, err)

ctx2, cancel2 := context.WithCancel(ctx)
cancel2()
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)

err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
})
}
14 changes: 9 additions & 5 deletions internal/stmtcache/lru_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() {
c.l = list.New()
}

// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
// Typically, the caller will then deallocate them.
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}

// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil
return invalidStmts
}

// Len returns the number of cached prepared statement descriptions.
Expand Down
9 changes: 7 additions & 2 deletions internal/stmtcache/stmtcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ type Cache interface {
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()

// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
HandleInvalidated() []*pgconn.StatementDescription
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription

// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()

// Len returns the number of cached prepared statement descriptions.
Len() int
Expand Down
12 changes: 9 additions & 3 deletions internal/stmtcache/unlimited_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
c.m = make(map[string]*pgconn.StatementDescription)
}

func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}

// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
c.invalidStmts = nil
return invalidStmts
}

// Len returns the number of cached prepared statement descriptions.
Expand Down

0 comments on commit 832b4f9

Please sign in to comment.