From 76593f37f74ec44e8e2ef2714d0785c886fcee44 Mon Sep 17 00:00:00 2001 From: zenkovev <89307802901@mail.ru> Date: Tue, 17 Dec 2024 11:49:13 +0300 Subject: [PATCH 1/3] add flush request in pipeline --- pgconn/pgconn.go | 33 ++++++++++++++++++++++ pgconn/pgconn_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 7efb522a4..723689264 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2093,6 +2093,22 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para p.conn.frontend.SendExecute(&pgproto3.Execute{}) } +// SendFlushRequest sends a request for the server to flush its output buffer. +// +// The server flushes its output buffer automatically as a result of Sync being called, +// or on any request when not in pipeline mode; this function is useful to cause the server +// to flush its output buffer in pipeline mode without establishing a synchronization point. +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendFlushRequest. +func (p *Pipeline) SendFlushRequest() { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.Send(&pgproto3.Flush{}) +} + // Flush flushes the queued requests without establishing a synchronization point. func (p *Pipeline) Flush() error { if p.closed { @@ -2157,6 +2173,23 @@ func (p *Pipeline) GetResults() (results any, err error) { return p.getResults() } +// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, +// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. +// +// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush, +// without using Sync. In this case, you need to identify on your own when all results are received and +// there is no need to call the method anymore. +func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) { + if p.closed { + if p.err != nil { + return nil, p.err + } + return nil, errors.New("pipeline closed") + } + + return p.getResults() +} + func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index a53061d18..b56d11fd9 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3003,6 +3003,70 @@ func TestPipelinePrepareQuery(t *testing.T) { ensureConnValid(t, pgConn) } +func TestPipelinePrepareQueryWithFlush(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResultsNotCheckSync() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "msg", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResultsNotCheckSync() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResultsNotCheckSync() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { t.Parallel() From de3f868c1d4c46fd53cf479ce2855bef177391d9 Mon Sep 17 00:00:00 2001 From: zenkovev Date: Mon, 6 Jan 2025 13:54:48 +0300 Subject: [PATCH 2/3] pipeline queue for client requests --- pgconn/pgconn.go | 220 ++++++++++++++++++++-------- pgconn/pgconn_test.go | 324 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 461 insertions(+), 83 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 723689264..28ee01a77 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "container/list" "context" "crypto/md5" "crypto/tls" @@ -1408,9 +1409,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - pipeline *Pipeline + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -1443,12 +1443,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: mrr.closed = true - if mrr.pipeline != nil { - mrr.pipeline.expectedReadyForQueryCount-- - } else { - mrr.pgConn.contextWatcher.Unwatch() - mrr.pgConn.unlock() - } + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1672,7 +1668,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.EmptyQueryResponse: rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) + pgErr := ErrorResponseToPgError(msg) + if rr.pipeline != nil { + rr.pipeline.state.HandleError(pgErr) + } + rr.concludeCommand(CommandTag{}, pgErr) } return msg, nil @@ -1999,9 +1999,7 @@ type Pipeline struct { conn *PgConn ctx context.Context - expectedReadyForQueryCount int - pendingSync bool - + state pipelineState err error closed bool } @@ -2012,6 +2010,122 @@ type PipelineSync struct{} // CloseComplete is returned by GetResults when a CloseComplete message is received. type CloseComplete struct{} +type pipelineRequestType int + +const ( + PIPELINE_NIL pipelineRequestType = iota + PIPELINE_PREPARE + PIPELINE_QUERY_PARAMS + PIPELINE_QUERY_PREPARED + PIPELINE_DEALLOCATE + PIPELINE_SYNC_REQUEST + PIPELINE_FLUSH_REQUEST +) + +type pipelineRequestEvent struct { + RequestType pipelineRequestType + WasSentToServer bool + BeforeFlushOrSync bool +} + +type pipelineState struct { + requestEventQueue list.List + lastRequestType pipelineRequestType + pgErr *PgError + expectedReadyForQueryCount int +} + +func (s *pipelineState) Init() { + s.requestEventQueue.Init() + s.lastRequestType = PIPELINE_NIL +} + +func (s *pipelineState) RegisterSendingToServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.WasSentToServer { + return + } + val.WasSentToServer = true + elem.Value = val + } +} + +func (s *pipelineState) registerFlushingBufferOnServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.BeforeFlushOrSync { + return + } + val.BeforeFlushOrSync = true + elem.Value = val + } +} + +func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { + if req == PIPELINE_NIL { + return + } + + if req != PIPELINE_FLUSH_REQUEST { + s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) + } + if req == PIPELINE_FLUSH_REQUEST || req == PIPELINE_SYNC_REQUEST { + s.registerFlushingBufferOnServer() + } + s.lastRequestType = req + + if req == PIPELINE_SYNC_REQUEST { + s.expectedReadyForQueryCount++ + } +} + +func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { + for { + elem := s.requestEventQueue.Front() + if elem == nil { + return PIPELINE_NIL + } + val := elem.Value.(pipelineRequestEvent) + if !(val.WasSentToServer && val.BeforeFlushOrSync) { + return PIPELINE_NIL + } + + s.requestEventQueue.Remove(elem) + if val.RequestType == PIPELINE_SYNC_REQUEST { + s.pgErr = nil + } + if s.pgErr == nil { + return val.RequestType + } + } +} + +func (s *pipelineState) HandleError(err *PgError) { + s.pgErr = err +} + +func (s *pipelineState) HandleReadyForQuery() { + s.expectedReadyForQueryCount-- +} + +func (s *pipelineState) PendingSync() bool { + var notPendingSync bool + + if elem := s.requestEventQueue.Back(); elem != nil { + val := elem.Value.(pipelineRequestEvent) + notPendingSync = (val.RequestType == PIPELINE_SYNC_REQUEST) && val.WasSentToServer + } else { + notPendingSync = (s.lastRequestType == PIPELINE_SYNC_REQUEST) || (s.lastRequestType == PIPELINE_NIL) + } + + return !notPendingSync +} + +func (s *pipelineState) ExpectedReadyForQuery() int { + return s.expectedReadyForQueryCount +} + // StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent // to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection // to normal mode. While in pipeline mode, no methods that communicate with the server may be called except @@ -2020,16 +2134,21 @@ type CloseComplete struct{} // Prefer ExecBatch when only sending one group of queries at once. func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { if err := pgConn.lock(); err != nil { - return &Pipeline{ + pipeline := &Pipeline{ closed: true, err: err, } + pipeline.state.Init() + + return pipeline } pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, } + pgConn.pipeline.state.Init() + pipeline := &pgConn.pipeline if ctx != context.Background() { @@ -2052,10 +2171,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(PIPELINE_PREPARE) } // SendDeallocate deallocates a prepared statement. @@ -2063,9 +2182,9 @@ func (p *Pipeline) SendDeallocate(name string) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(PIPELINE_DEALLOCATE) } // SendQueryParams is the pipeline version of *PgConn.QueryParams. @@ -2073,12 +2192,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [ if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(PIPELINE_QUERY_PARAMS) } // SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. @@ -2086,11 +2205,11 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para if p.closed { return } - p.pendingSync = true p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(PIPELINE_QUERY_PREPARED) } // SendFlushRequest sends a request for the server to flush its output buffer. @@ -2104,9 +2223,24 @@ func (p *Pipeline) SendFlushRequest() { if p.closed { return } - p.pendingSync = true p.conn.frontend.Send(&pgproto3.Flush{}) + p.state.PushBackRequestType(PIPELINE_FLUSH_REQUEST) +} + +// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message +// without flushing the send buffer. This serves as the delimiter of an implicit +// transaction and an error recovery point. +// +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendPipelineSync. +func (p *Pipeline) SendPipelineSync() { + if p.closed { + return + } + + p.conn.frontend.SendSync(&pgproto3.Sync{}) + p.state.PushBackRequestType(PIPELINE_SYNC_REQUEST) } // Flush flushes the queued requests without establishing a synchronization point. @@ -2131,28 +2265,14 @@ func (p *Pipeline) Flush() error { return err } + p.state.RegisterSendingToServer() return nil } // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { - if p.closed { - if p.err != nil { - return p.err - } - return errors.New("pipeline closed") - } - - p.conn.frontend.SendSync(&pgproto3.Sync{}) - err := p.Flush() - if err != nil { - return err - } - - p.pendingSync = false - p.expectedReadyForQueryCount++ - - return nil + p.SendPipelineSync() + return p.Flush() } // GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or @@ -2166,30 +2286,13 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.expectedReadyForQueryCount == 0 { + if p.state.ExtractFrontRequestType() == PIPELINE_NIL { return nil, nil } return p.getResults() } -// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, -// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. -// -// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush, -// without using Sync. In this case, you need to identify on your own when all results are received and -// there is no need to call the method anymore. -func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) { - if p.closed { - if p.err != nil { - return nil, p.err - } - return nil, errors.New("pipeline closed") - } - - return p.getResults() -} - func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() @@ -2228,13 +2331,13 @@ func (p *Pipeline) getResults() (results any, err error) { case *pgproto3.CloseComplete: return &CloseComplete{}, nil case *pgproto3.ReadyForQuery: - p.expectedReadyForQueryCount-- + p.state.HandleReadyForQuery() return &PipelineSync{}, nil case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr } - } } @@ -2264,6 +2367,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { // These should never happen here. But don't take chances that could lead to a deadlock. case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr case *pgproto3.CommandComplete: p.conn.asyncClose() @@ -2283,7 +2387,7 @@ func (p *Pipeline) Close() error { p.closed = true - if p.pendingSync { + if p.state.PendingSync() { p.conn.asyncClose() p.err = errors.New("pipeline has unsynced requests") p.conn.contextWatcher.Unwatch() @@ -2292,7 +2396,7 @@ func (p *Pipeline) Close() error { return p.err } - for p.expectedReadyForQueryCount > 0 { + for p.state.ExpectedReadyForQuery() > 0 { _, err := p.getResults() if err != nil { p.err = err diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index b56d11fd9..f800677d8 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3003,7 +3003,114 @@ func TestPipelinePrepareQuery(t *testing.T) { ensureConnValid(t, pgConn) } -func TestPipelinePrepareQueryWithFlush(t *testing.T) { +func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushForSingleRequests(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) @@ -3014,14 +3121,13 @@ func TestPipelinePrepareQueryWithFlush(t *testing.T) { defer closeConn(t, pgConn) pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) - pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) - pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) pipeline.SendFlushRequest() err = pipeline.Flush() require.NoError(t, err) - results, err := pipeline.GetResultsNotCheckSync() + results, err := pipeline.GetResults() require.NoError(t, err) sd, ok := results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) @@ -3029,7 +3135,16 @@ func TestPipelinePrepareQueryWithFlush(t *testing.T) { require.Equal(t, "msg", string(sd.Fields[0].Name)) require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) - results, err = pipeline.GetResultsNotCheckSync() + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() require.NoError(t, err) rr, ok := results.(*pgconn.ResultReader) require.Truef(t, ok, "expected ResultReader, got: %#v", results) @@ -3039,7 +3154,30 @@ func TestPipelinePrepareQueryWithFlush(t *testing.T) { require.Len(t, readResult.Rows[0], 1) require.Equal(t, "hello", string(readResult.Rows[0][0])) - results, err = pipeline.GetResultsNotCheckSync() + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendDeallocate("ps") + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() require.NoError(t, err) rr, ok = results.(*pgconn.ResultReader) require.Truef(t, ok, "expected ResultReader, got: %#v", results) @@ -3047,7 +3185,11 @@ func TestPipelinePrepareQueryWithFlush(t *testing.T) { require.NoError(t, readResult.Err) require.Len(t, readResult.Rows, 1) require.Len(t, readResult.Rows[0], 1) - require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) err = pipeline.Sync() require.NoError(t, err) @@ -3067,7 +3209,7 @@ func TestPipelinePrepareQueryWithFlush(t *testing.T) { ensureConnValid(t, pgConn) } -func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { +func TestPipelineFlushForRequestSeries(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) @@ -3078,23 +3220,30 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { defer closeConn(t, pgConn) pipeline := pgConn.StartPipeline(ctx) - pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendPrepare("ps", "select $1::bigint as num", nil) err = pipeline.Sync() require.NoError(t, err) - pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) - err = pipeline.Sync() + results, err := pipeline.GetResults() require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "num", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) - pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil) - err = pipeline.Sync() + results, err = pipeline.GetResults() require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) - results, err := pipeline.GetResults() + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() require.NoError(t, err) rr, ok := results.(*pgconn.ResultReader) require.Truef(t, ok, "expected ResultReader, got: %#v", results) @@ -3116,8 +3265,20 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { results, err = pipeline.GetResults() require.NoError(t, err) - _, ok = results.(*pgconn.PipelineSync) - require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil) + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) results, err = pipeline.GetResults() require.NoError(t, err) @@ -3134,14 +3295,26 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { rr, ok = results.(*pgconn.ResultReader) require.Truef(t, ok, "expected ResultReader, got: %#v", results) readResult = rr.Read() - var pgErr *pgconn.PgError - require.ErrorAs(t, readResult.Err, &pgErr) - require.Equal(t, "22012", pgErr.Code) + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) results, err = pipeline.GetResults() require.NoError(t, err) - _, ok = results.(*pgconn.PipelineSync) - require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil) + pipeline.SendFlushRequest() + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) results, err = pipeline.GetResults() require.NoError(t, err) @@ -3163,6 +3336,107 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { require.Len(t, readResult.Rows[0], 1) require.Equal(t, "6", string(readResult.Rows[0][0])) + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushWithError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendPipelineSync() + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + results, err = pipeline.GetResults() require.NoError(t, err) _, ok = results.(*pgconn.PipelineSync) From c96a55f8c0d90def96cb626f44e2948daf7da90f Mon Sep 17 00:00:00 2001 From: zenkovev Date: Sat, 11 Jan 2025 19:54:18 +0300 Subject: [PATCH 3/3] private const for pipelineRequestType --- pgconn/pgconn.go | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 28ee01a77..59b89cf7d 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2013,13 +2013,13 @@ type CloseComplete struct{} type pipelineRequestType int const ( - PIPELINE_NIL pipelineRequestType = iota - PIPELINE_PREPARE - PIPELINE_QUERY_PARAMS - PIPELINE_QUERY_PREPARED - PIPELINE_DEALLOCATE - PIPELINE_SYNC_REQUEST - PIPELINE_FLUSH_REQUEST + pipelineNil pipelineRequestType = iota + pipelinePrepare + pipelineQueryParams + pipelineQueryPrepared + pipelineDeallocate + pipelineSyncRequest + pipelineFlushRequest ) type pipelineRequestEvent struct { @@ -2037,7 +2037,7 @@ type pipelineState struct { func (s *pipelineState) Init() { s.requestEventQueue.Init() - s.lastRequestType = PIPELINE_NIL + s.lastRequestType = pipelineNil } func (s *pipelineState) RegisterSendingToServer() { @@ -2063,19 +2063,19 @@ func (s *pipelineState) registerFlushingBufferOnServer() { } func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { - if req == PIPELINE_NIL { + if req == pipelineNil { return } - if req != PIPELINE_FLUSH_REQUEST { + if req != pipelineFlushRequest { s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) } - if req == PIPELINE_FLUSH_REQUEST || req == PIPELINE_SYNC_REQUEST { + if req == pipelineFlushRequest || req == pipelineSyncRequest { s.registerFlushingBufferOnServer() } s.lastRequestType = req - if req == PIPELINE_SYNC_REQUEST { + if req == pipelineSyncRequest { s.expectedReadyForQueryCount++ } } @@ -2084,15 +2084,15 @@ func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { for { elem := s.requestEventQueue.Front() if elem == nil { - return PIPELINE_NIL + return pipelineNil } val := elem.Value.(pipelineRequestEvent) if !(val.WasSentToServer && val.BeforeFlushOrSync) { - return PIPELINE_NIL + return pipelineNil } s.requestEventQueue.Remove(elem) - if val.RequestType == PIPELINE_SYNC_REQUEST { + if val.RequestType == pipelineSyncRequest { s.pgErr = nil } if s.pgErr == nil { @@ -2114,9 +2114,9 @@ func (s *pipelineState) PendingSync() bool { if elem := s.requestEventQueue.Back(); elem != nil { val := elem.Value.(pipelineRequestEvent) - notPendingSync = (val.RequestType == PIPELINE_SYNC_REQUEST) && val.WasSentToServer + notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer } else { - notPendingSync = (s.lastRequestType == PIPELINE_SYNC_REQUEST) || (s.lastRequestType == PIPELINE_NIL) + notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil) } return !notPendingSync @@ -2174,7 +2174,7 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) - p.state.PushBackRequestType(PIPELINE_PREPARE) + p.state.PushBackRequestType(pipelinePrepare) } // SendDeallocate deallocates a prepared statement. @@ -2184,7 +2184,7 @@ func (p *Pipeline) SendDeallocate(name string) { } p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) - p.state.PushBackRequestType(PIPELINE_DEALLOCATE) + p.state.PushBackRequestType(pipelineDeallocate) } // SendQueryParams is the pipeline version of *PgConn.QueryParams. @@ -2197,7 +2197,7 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [ p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) - p.state.PushBackRequestType(PIPELINE_QUERY_PARAMS) + p.state.PushBackRequestType(pipelineQueryParams) } // SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. @@ -2209,7 +2209,7 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) - p.state.PushBackRequestType(PIPELINE_QUERY_PREPARED) + p.state.PushBackRequestType(pipelineQueryPrepared) } // SendFlushRequest sends a request for the server to flush its output buffer. @@ -2225,7 +2225,7 @@ func (p *Pipeline) SendFlushRequest() { } p.conn.frontend.Send(&pgproto3.Flush{}) - p.state.PushBackRequestType(PIPELINE_FLUSH_REQUEST) + p.state.PushBackRequestType(pipelineFlushRequest) } // SendPipelineSync marks a synchronization point in a pipeline by sending a sync message @@ -2240,7 +2240,7 @@ func (p *Pipeline) SendPipelineSync() { } p.conn.frontend.SendSync(&pgproto3.Sync{}) - p.state.PushBackRequestType(PIPELINE_SYNC_REQUEST) + p.state.PushBackRequestType(pipelineSyncRequest) } // Flush flushes the queued requests without establishing a synchronization point. @@ -2286,7 +2286,7 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.state.ExtractFrontRequestType() == PIPELINE_NIL { + if p.state.ExtractFrontRequestType() == pipelineNil { return nil, nil }