Skip to content

Commit

Permalink
Fix inFlightWG race in the OTel-Arrow receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
jmacd committed Jun 4, 2024
1 parent 367e229 commit 9726e00
Showing 1 changed file with 39 additions and 36 deletions.
75 changes: 39 additions & 36 deletions collector/receiver/otelarrowreceiver/internal/arrow/arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ type Receiver struct {
recvInFlightItems metric.Int64UpDownCounter
recvInFlightRequests metric.Int64UpDownCounter
boundedQueue *admission.BoundedQueue
inFlightWG sync.WaitGroup
}

// receiverStream holds the inFlightWG for a single stream.
type receiverStream struct {
*Receiver
inFlightWG sync.WaitGroup
}

// New creates a new Receiver reference.
Expand Down Expand Up @@ -306,9 +311,9 @@ func (r *Receiver) logStreamError(err error, where string) {
}

if code == codes.Canceled {
r.telemetry.Logger.Debug("arrow stream shutdown", zap.String("message", msg))
r.telemetry.Logger.Debug("arrow stream shutdown", zap.String("message", msg), zap.String("where", where))
} else {
r.telemetry.Logger.Error("arrow stream error", zap.String("message", msg), zap.Int("code", int(code)), zap.String("where", where))
r.telemetry.Logger.Error("arrow stream error", zap.Int("code", int(code)), zap.String("message", msg), zap.String("where", where))
}
}

Expand Down Expand Up @@ -381,34 +386,34 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr

// wg is used to ensure this thread returns after both
// sender and recevier threads return.
var wg sync.WaitGroup
wg.Add(2)
var sendWG sync.WaitGroup
var recvWG sync.WaitGroup
sendWG.Add(1)
recvWG.Add(1)

// The inflightWG is used to wait for all data to send. The
// 1-count here is removed after srvReceiveLoop() returns,
// having this ensures that concurrent calls to Add() in the
// receiver do not race with Wait() in the sender.
r.inFlightWG.Add(1)
rstream := &receiverStream{
Receiver: r,
}

go func() {
var err error
defer wg.Done()
defer recvWG.Done()
defer r.recoverErr(&err)
defer r.inFlightWG.Done()
err = r.srvReceiveLoop(doneCtx, serverStream, pendingCh, method, ac)
err = rstream.srvReceiveLoop(doneCtx, serverStream, pendingCh, method, ac)
streamErrCh <- err
}()

go func() {
var err error
defer wg.Done()
defer sendWG.Done()
defer r.recoverErr(&err)
err = r.srvSendLoop(doneCtx, serverStream, pendingCh)
err = rstream.srvSendLoop(doneCtx, serverStream, &recvWG, pendingCh)
streamErrCh <- err
}()

// Wait for sender/receiver threads to return before returning.
defer wg.Wait()
defer recvWG.Wait()
defer sendWG.Wait()

select {
case <-doneCtx.Done():
Expand All @@ -419,17 +424,17 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr
}
}

func (r *Receiver) newInFlightData(ctx context.Context, method string, batchID int64, pendingCh chan<- batchResp) (context.Context, *inFlightData) {
func (r *receiverStream) newInFlightData(ctx context.Context, method string, batchID int64, pendingCh chan<- batchResp) (context.Context, *inFlightData) {
ctx, span := r.tracer.Start(ctx, "otel_arrow_stream_inflight")

r.inFlightWG.Add(1)
r.recvInFlightRequests.Add(ctx, 1)
id := &inFlightData{
Receiver: r,
method: method,
batchID: batchID,
pendingCh: pendingCh,
span: span,
receiverStream: r,
method: method,
batchID: batchID,
pendingCh: pendingCh,
span: span,
}
id.refs.Add(1)
return ctx, id
Expand All @@ -438,7 +443,7 @@ func (r *Receiver) newInFlightData(ctx context.Context, method string, batchID i
// inFlightData is responsible for storing the resources held by one request.
type inFlightData struct {
// Receiver is the owner of the resources held by this object.
*Receiver
*receiverStream

method string
batchID int64
Expand Down Expand Up @@ -539,7 +544,7 @@ func (id *inFlightData) anyDone(ctx context.Context) {
// This handles constructing an inFlightData object, which itself
// tracks everything that needs to be used by instrumention when the
// batch finishes.
func (r *Receiver) recvOne(streamCtx context.Context, serverStream anyStreamServer, hrcv *headerReceiver, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) {
func (r *receiverStream) recvOne(streamCtx context.Context, serverStream anyStreamServer, hrcv *headerReceiver, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) {

// Receive a batch corresponding with one ptrace.Traces, pmetric.Metrics,
// or plog.Logs item.
Expand Down Expand Up @@ -650,7 +655,7 @@ func (r *Receiver) consumeAndRespond(ctx context.Context, data any, flight *inFl
}

// srvReceiveLoop repeatedly receives one batch of data.
func (r *Receiver) srvReceiveLoop(ctx context.Context, serverStream anyStreamServer, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) {
func (r *receiverStream) srvReceiveLoop(ctx context.Context, serverStream anyStreamServer, pendingCh chan<- batchResp, method string, ac arrowRecord.ConsumerAPI) (retErr error) {
hrcv := newHeaderReceiver(ctx, r.authServer, r.gsettings.IncludeMetadata)
for {
select {
Expand All @@ -665,7 +670,7 @@ func (r *Receiver) srvReceiveLoop(ctx context.Context, serverStream anyStreamSer
}

// srvReceiveLoop repeatedly sends one batch data response.
func (r *Receiver) sendOne(serverStream anyStreamServer, resp batchResp) error {
func (r *receiverStream) sendOne(serverStream anyStreamServer, resp batchResp) error {
// Note: Statuses can be batched, but we do not take
// advantage of this feature.
bs := &arrowpb.BatchStatus{
Expand Down Expand Up @@ -709,19 +714,17 @@ func (r *Receiver) sendOne(serverStream anyStreamServer, resp batchResp) error {
return nil
}

func (r *Receiver) flushSender(serverStream anyStreamServer, pendingCh <-chan batchResp) error {
var err error
// wait for all in flight requests to be successfully
// processed or fail. this implies waiting for the receiver
// loop to exit, as it holds one additional wait count to
// avoid a race with Add() here.
func (r *receiverStream) flushSender(serverStream anyStreamServer, recvWG *sync.WaitGroup, pendingCh <-chan batchResp) error {
// wait to ensure no more items are accepted
recvWG.Wait()

// wait for all responses to be sent
r.inFlightWG.Wait()

for {
select {
case resp := <-pendingCh:
err = r.sendOne(serverStream, resp)
if err != nil {
if err := r.sendOne(serverStream, resp); err != nil {
return err
}
default:
Expand All @@ -731,11 +734,11 @@ func (r *Receiver) flushSender(serverStream anyStreamServer, pendingCh <-chan ba
}
}

func (r *Receiver) srvSendLoop(ctx context.Context, serverStream anyStreamServer, pendingCh <-chan batchResp) error {
func (r *receiverStream) srvSendLoop(ctx context.Context, serverStream anyStreamServer, recvWG *sync.WaitGroup, pendingCh <-chan batchResp) error {
for {
select {
case <-ctx.Done():
return r.flushSender(serverStream, pendingCh)
return r.flushSender(serverStream, recvWG, pendingCh)
case resp := <-pendingCh:
if err := r.sendOne(serverStream, resp); err != nil {
return err
Expand Down

0 comments on commit 9726e00

Please sign in to comment.