diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 7efb522a4..723689264 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2093,6 +2093,22 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para p.conn.frontend.SendExecute(&pgproto3.Execute{}) } +// SendFlushRequest sends a request for the server to flush its output buffer. +// +// The server flushes its output buffer automatically as a result of Sync being called, +// or on any request when not in pipeline mode; this function is useful to cause the server +// to flush its output buffer in pipeline mode without establishing a synchronization point. +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendFlushRequest. +func (p *Pipeline) SendFlushRequest() { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.Send(&pgproto3.Flush{}) +} + // Flush flushes the queued requests without establishing a synchronization point. func (p *Pipeline) Flush() error { if p.closed { @@ -2157,6 +2173,23 @@ func (p *Pipeline) GetResults() (results any, err error) { return p.getResults() } +// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, +// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. +// +// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush, +// without using Sync. In this case, you need to identify on your own when all results are received and +// there is no need to call the method anymore. +func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) { + if p.closed { + if p.err != nil { + return nil, p.err + } + return nil, errors.New("pipeline closed") + } + + return p.getResults() +} + func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index a53061d18..b56d11fd9 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3003,6 +3003,70 @@ func TestPipelinePrepareQuery(t *testing.T) { ensureConnValid(t, pgConn) } +func TestPipelinePrepareQueryWithFlush(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResultsNotCheckSync() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "msg", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResultsNotCheckSync() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResultsNotCheckSync() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { t.Parallel()