From 9726e00be38dab11a43454bf047c717b020ba06b Mon Sep 17 00:00:00 2001 From: Joshua MacDonald Date: Tue, 4 Jun 2024 15:00:03 -0700 Subject: [PATCH] Fix inFlightWG race in the OTel-Arrow receiver --- .../otelarrowreceiver/internal/arrow/arrow.go | 75 ++++++++++--------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/collector/receiver/otelarrowreceiver/internal/arrow/arrow.go b/collector/receiver/otelarrowreceiver/internal/arrow/arrow.go index 4159822e..5515b7fe 100644 --- a/collector/receiver/otelarrowreceiver/internal/arrow/arrow.go +++ b/collector/receiver/otelarrowreceiver/internal/arrow/arrow.go @@ -83,7 +83,12 @@ type Receiver struct { recvInFlightItems metric.Int64UpDownCounter recvInFlightRequests metric.Int64UpDownCounter boundedQueue *admission.BoundedQueue - inFlightWG sync.WaitGroup +} + +// receiverStream holds the inFlightWG for a single stream. +type receiverStream struct { + *Receiver + inFlightWG sync.WaitGroup } // New creates a new Receiver reference. @@ -306,9 +311,9 @@ func (r *Receiver) logStreamError(err error, where string) { } if code == codes.Canceled { - r.telemetry.Logger.Debug("arrow stream shutdown", zap.String("message", msg)) + r.telemetry.Logger.Debug("arrow stream shutdown", zap.String("message", msg), zap.String("where", where)) } else { - r.telemetry.Logger.Error("arrow stream error", zap.String("message", msg), zap.Int("code", int(code)), zap.String("where", where)) + r.telemetry.Logger.Error("arrow stream error", zap.Int("code", int(code)), zap.String("message", msg), zap.String("where", where)) } } @@ -381,34 +386,34 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr // wg is used to ensure this thread returns after both // sender and recevier threads return. - var wg sync.WaitGroup - wg.Add(2) + var sendWG sync.WaitGroup + var recvWG sync.WaitGroup + sendWG.Add(1) + recvWG.Add(1) - // The inflightWG is used to wait for all data to send. The - // 1-count here is removed after srvReceiveLoop() returns, - // having this ensures that concurrent calls to Add() in the - // receiver do not race with Wait() in the sender. - r.inFlightWG.Add(1) + rstream := &receiverStream{ + Receiver: r, + } go func() { var err error - defer wg.Done() + defer recvWG.Done() defer r.recoverErr(&err) - defer r.inFlightWG.Done() - err = r.srvReceiveLoop(doneCtx, serverStream, pendingCh, method, ac) + err = rstream.srvReceiveLoop(doneCtx, serverStream, pendingCh, method, ac) streamErrCh <- err }() go func() { var err error - defer wg.Done() + defer sendWG.Done() defer r.recoverErr(&err) - err = r.srvSendLoop(doneCtx, serverStream, pendingCh) + err = rstream.srvSendLoop(doneCtx, serverStream, &recvWG, pendingCh) streamErrCh <- err }() // Wait for sender/receiver threads to return before returning. - defer wg.Wait() + defer recvWG.Wait() + defer sendWG.Wait() select { case <-doneCtx.Done(): @@ -419,17 +424,17 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr } } -func (r *Receiver) newInFlightData(ctx context.Context, method string, batchID int64, pendingCh chan<- batchResp) (context.Context, *inFlightData) { +func (r *receiverStream) newInFlightData(ctx context.Context, method string, batchID int64, pendingCh chan<- batchResp) (context.Context, *inFlightData) { ctx, span := r.tracer.Start(ctx, "otel_arrow_stream_inflight") r.inFlightWG.Add(1) r.recvInFlightRequests.Add(ctx, 1) id := &inFlightData{ - Receiver: r, - method: method, - batchID: batchID, - pendingCh: pendingCh, - span: span, + receiverStream: r, + method: method, + batchID: batchID, + pendingCh: pendingCh, + span: span, } id.refs.Add(1) return ctx, id @@ -438,7 +443,7 @@ func (r *Receiver) newInFlightData(ctx context.Context, method string, batchID i // inFlightData is responsible for storing the resources held by one request. type inFlightData struct { // Receiver is the owner of the resources held by this object. - *Receiver + *receiverStream method string batchID int64 @@ -539,7 +544,7 @@ func (id *inFlightData) anyDone(ctx context.Context) { // This handles constructing an inFlightData object, which itself // tracks everything that needs to be used by instrumention when the // batch finishes. -func (r *Receiver) recvOne(streamCtx context.Context, serverStream anyStreamServer, hrcv *headerReceiver, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) { +func (r *receiverStream) recvOne(streamCtx context.Context, serverStream anyStreamServer, hrcv *headerReceiver, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) { // Receive a batch corresponding with one ptrace.Traces, pmetric.Metrics, // or plog.Logs item. @@ -650,7 +655,7 @@ func (r *Receiver) consumeAndRespond(ctx context.Context, data any, flight *inFl } // srvReceiveLoop repeatedly receives one batch of data. -func (r *Receiver) srvReceiveLoop(ctx context.Context, serverStream anyStreamServer, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) { +func (r *receiverStream) srvReceiveLoop(ctx context.Context, serverStream anyStreamServer, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) { hrcv := newHeaderReceiver(ctx, r.authServer, r.gsettings.IncludeMetadata) for { select { @@ -665,7 +670,7 @@ func (r *Receiver) srvReceiveLoop(ctx context.Context, serverStream anyStreamSer } // srvReceiveLoop repeatedly sends one batch data response. -func (r *Receiver) sendOne(serverStream anyStreamServer, resp batchResp) error { +func (r *receiverStream) sendOne(serverStream anyStreamServer, resp batchResp) error { // Note: Statuses can be batched, but we do not take // advantage of this feature. bs := &arrowpb.BatchStatus{ @@ -709,19 +714,17 @@ func (r *Receiver) sendOne(serverStream anyStreamServer, resp batchResp) error { return nil } -func (r *Receiver) flushSender(serverStream anyStreamServer, pendingCh <-chan batchResp) error { - var err error - // wait for all in flight requests to be successfully - // processed or fail. this implies waiting for the receiver - // loop to exit, as it holds one additional wait count to - // avoid a race with Add() here. +func (r *receiverStream) flushSender(serverStream anyStreamServer, recvWG *sync.WaitGroup, pendingCh <-chan batchResp) error { + // wait to ensure no more items are accepted + recvWG.Wait() + + // wait for all responses to be sent r.inFlightWG.Wait() for { select { case resp := <-pendingCh: - err = r.sendOne(serverStream, resp) - if err != nil { + if err := r.sendOne(serverStream, resp); err != nil { return err } default: @@ -731,11 +734,11 @@ func (r *Receiver) flushSender(serverStream anyStreamServer, pendingCh <-chan ba } } -func (r *Receiver) srvSendLoop(ctx context.Context, serverStream anyStreamServer, pendingCh <-chan batchResp) error { +func (r *receiverStream) srvSendLoop(ctx context.Context, serverStream anyStreamServer, recvWG *sync.WaitGroup, pendingCh <-chan batchResp) error { for { select { case <-ctx.Done(): - return r.flushSender(serverStream, pendingCh) + return r.flushSender(serverStream, recvWG, pendingCh) case resp := <-pendingCh: if err := r.sendOne(serverStream, resp); err != nil { return err