Skip to content

Commit

Permalink
Add context to ProducerInterceptor
Browse files Browse the repository at this point in the history
This change adds a context.Context argument to the ProducerInterceptor
interface, and passes it between the pre- and post-Send interceptor
methods. Having this makes it much easier to write useful
interceptors that can integrate with common tracing SDKs like
OpenTelemetry, as the context is the conventional method for propagating
metadata vertically through a call stack.

For an example of another library using a similar convention, see:
https://github.com/jackc/pgx/blob/9ab9e3c40bbb33c6f37359c87508cbc6a9830ed6/tracer.go#L10

Fixes #443
  • Loading branch information
treuherz committed Jan 17, 2024
1 parent 4e13822 commit 011d1f2
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 37 deletions.
15 changes: 11 additions & 4 deletions pulsar/internal/pulsartracing/producer_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@ const toPrefix = "To__"
type ProducerInterceptor struct {
}

func (t *ProducerInterceptor) BeforeSend(producer pulsar.Producer, message *pulsar.ProducerMessage) {
func (t *ProducerInterceptor) BeforeSend(
ctx context.Context,
producer pulsar.Producer,
message *pulsar.ProducerMessage,
) context.Context {
buildAndInjectSpan(message, producer).Finish()
return ctx
}

func (t *ProducerInterceptor) OnSendAcknowledgement(producer pulsar.Producer,
message *pulsar.ProducerMessage,
msgID pulsar.MessageID) {
func (t *ProducerInterceptor) OnSendAcknowledgement(
_ context.Context,
_ pulsar.Producer,
_ *pulsar.ProducerMessage,
_ pulsar.MessageID) {
}

func buildAndInjectSpan(message *pulsar.ProducerMessage, producer pulsar.Producer) opentracing.Span {
Expand Down
15 changes: 9 additions & 6 deletions pulsar/producer_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,30 @@

package pulsar

import "context"

type ProducerInterceptor interface {
// BeforeSend This is called before send the message to the brokers. This method is allowed to modify the
// message.
BeforeSend(producer Producer, message *ProducerMessage)
BeforeSend(ctx context.Context, producer Producer, message *ProducerMessage) context.Context

// OnSendAcknowledgement This method is called when the message sent to the broker has been acknowledged,
// or when sending the message fails.
OnSendAcknowledgement(producer Producer, message *ProducerMessage, msgID MessageID)
OnSendAcknowledgement(ctx context.Context, producer Producer, message *ProducerMessage, msgID MessageID)
}

type ProducerInterceptors []ProducerInterceptor

func (x ProducerInterceptors) BeforeSend(producer Producer, message *ProducerMessage) {
func (x ProducerInterceptors) BeforeSend(ctx context.Context, producer Producer, message *ProducerMessage) context.Context {
for i := range x {
x[i].BeforeSend(producer, message)
ctx = x[i].BeforeSend(ctx, producer, message)
}
return ctx
}

func (x ProducerInterceptors) OnSendAcknowledgement(producer Producer, message *ProducerMessage, msgID MessageID) {
func (x ProducerInterceptors) OnSendAcknowledgement(ctx context.Context, producer Producer, message *ProducerMessage, msgID MessageID) {
for i := range x {
x[i].OnSendAcknowledgement(producer, message, msgID)
x[i].OnSendAcknowledgement(ctx, producer, message, msgID)
}
}

Expand Down
25 changes: 14 additions & 11 deletions pulsar/producer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ func (p *partitionProducer) Send(ctx context.Context, msg *ProducerMessage) (Mes
isDone := uAtomic.NewBool(false)
doneCh := make(chan struct{})

p.internalSendAsync(ctx, msg, func(ID MessageID, message *ProducerMessage, e error) {
ctx = p.internalSendAsync(ctx, msg, func(ID MessageID, message *ProducerMessage, e error) {
if isDone.CAS(false, true) {
err = e
msgID = ID
Expand Down Expand Up @@ -1202,11 +1202,11 @@ func (p *partitionProducer) internalSendAsync(
msg *ProducerMessage,
callback func(MessageID, *ProducerMessage, error),
flushImmediately bool,
) {
) context.Context {
if err := p.validateMsg(msg); err != nil {
p.log.Error(err)
runCallback(callback, nil, msg, err)
return
return ctx
}

sr := sendRequestPool.Get().(*sendRequest)
Expand All @@ -1224,43 +1224,46 @@ func (p *partitionProducer) internalSendAsync(

if err := p.prepareTransaction(sr); err != nil {
sr.done(nil, err)
return
return ctx
}

if p.getProducerState() != producerReady {
sr.done(nil, ErrProducerClosed)
return
return ctx
}

p.options.Interceptors.BeforeSend(p, msg)
ctx = p.options.Interceptors.BeforeSend(ctx, p, msg)
sr.ctx = ctx

if err := p.updateSchema(sr); err != nil {
p.log.Error(err)
sr.done(nil, err)
return
return ctx
}

if err := p.updateUncompressedPayload(sr); err != nil {
p.log.Error(err)
sr.done(nil, err)
return
return ctx
}

p.updateMetaData(sr)

if err := p.updateChunkInfo(sr); err != nil {
p.log.Error(err)
sr.done(nil, err)
return
return ctx
}

if err := p.reserveResources(sr); err != nil {
p.log.Error(err)
sr.done(nil, err)
return
return ctx
}

p.dataChan <- sr

return ctx
}

func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) {
Expand Down Expand Up @@ -1505,7 +1508,7 @@ func (sr *sendRequest) done(msgID MessageID, err error) {

if sr.totalChunks <= 1 || sr.chunkID == sr.totalChunks-1 {
if sr.producer.options.Interceptors != nil {
sr.producer.options.Interceptors.OnSendAcknowledgement(sr.producer, sr.msg, msgID)
sr.producer.options.Interceptors.OnSendAcknowledgement(sr.ctx, sr.producer, sr.msg, msgID)
}
}
}
Expand Down
48 changes: 32 additions & 16 deletions pulsar/producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1479,23 +1479,38 @@ func TestProducuerSendFailOnInvalidKey(t *testing.T) {

type noopProduceInterceptor struct{}

func (noopProduceInterceptor) BeforeSend(producer Producer, message *ProducerMessage) {}
func (noopProduceInterceptor) BeforeSend(ctx context.Context, _ Producer, _ *ProducerMessage) context.Context {
return ctx
}

func (noopProduceInterceptor) OnSendAcknowledgement(producer Producer, message *ProducerMessage, msgID MessageID) {
func (noopProduceInterceptor) OnSendAcknowledgement(_ context.Context, _ Producer, _ *ProducerMessage, _ MessageID) {
}

// copyPropertyIntercepotr copy all keys in message properties map and add a suffix
type metricProduceInterceptor struct {
sendn int
ackn int
type trackingProduceInterceptor struct {
sendn int
ackn int
maxDuration time.Duration
}

func (x *metricProduceInterceptor) BeforeSend(producer Producer, message *ProducerMessage) {
x.sendn++
type beforeSendCtxKey struct{}

func (i *trackingProduceInterceptor) BeforeSend(ctx context.Context, _ Producer, msg *ProducerMessage) context.Context {
i.sendn++
ctx = context.WithValue(ctx, beforeSendCtxKey{}, time.Now())
return ctx
}

func (x *metricProduceInterceptor) OnSendAcknowledgement(producer Producer, message *ProducerMessage, msgID MessageID) {
x.ackn++
func (i *trackingProduceInterceptor) OnSendAcknowledgement(ctx context.Context, _ Producer, _ *ProducerMessage, _ MessageID) {
var dur time.Duration
if v := ctx.Value(beforeSendCtxKey{}); v != nil {
dur = time.Since(v.(time.Time))
}

if dur > i.maxDuration {
i.maxDuration = dur
}

i.ackn++
}

func TestProducerWithInterceptors(t *testing.T) {
Expand All @@ -1518,14 +1533,14 @@ func TestProducerWithInterceptors(t *testing.T) {
assert.Nil(t, err)
defer consumer.Close()

metric := &metricProduceInterceptor{}
interceptor := &trackingProduceInterceptor{}
// create producer
producer, err := client.CreateProducer(ProducerOptions{
Topic: topic,
DisableBatching: false,
Interceptors: ProducerInterceptors{
noopProduceInterceptor{},
metric,
interceptor,
},
})
assert.Nil(t, err)
Expand Down Expand Up @@ -1575,8 +1590,9 @@ func TestProducerWithInterceptors(t *testing.T) {
consumer.Ack(msg)
}

assert.Equal(t, 10, metric.sendn)
assert.Equal(t, 10, metric.ackn)
assert.Equal(t, 10, interceptor.sendn)
assert.Equal(t, 10, interceptor.ackn)
assert.NotZero(t, interceptor.maxDuration)
}

func TestProducerSendAfterClose(t *testing.T) {
Expand Down Expand Up @@ -1719,7 +1735,7 @@ func TestMultipleSchemaOfKeyBasedBatchProducerConsumer(t *testing.T) {
}
producer.Flush()

//// create consumer
// create consumer
consumer, err := client.Subscribe(ConsumerOptions{
Topic: topic,
SubscriptionName: "my-sub2",
Expand Down Expand Up @@ -1810,7 +1826,7 @@ func TestMultipleSchemaProducerConsumer(t *testing.T) {
}
producer.Flush()

//// create consumer
// create consumer
consumer, err := client.Subscribe(ConsumerOptions{
Topic: topic,
SubscriptionName: "my-sub2",
Expand Down

0 comments on commit 011d1f2

Please sign in to comment.