diff --git a/pkg/model/trace/combine.go b/pkg/model/trace/combine.go index e21720ef286..d68c338472b 100644 --- a/pkg/model/trace/combine.go +++ b/pkg/model/trace/combine.go @@ -5,6 +5,7 @@ import ( "fmt" "hash" "hash/fnv" + "sync" "github.com/grafana/tempo/pkg/tempopb" ) @@ -41,6 +42,7 @@ var ErrTraceTooLarge = fmt.Errorf("trace exceeds max size") // * Only sort the final result once and if needed. // * Don't scan/hash the spans for the last input (final=true). type Combiner struct { + mtx sync.Mutex result *tempopb.Trace spans map[token]struct{} combined bool @@ -53,6 +55,7 @@ type Combiner struct { // when allowPartialTrace is set to true a partial trace that exceed the max size may be returned func NewCombiner(maxSizeBytes int, allowPartialTrace bool) *Combiner { return &Combiner{ + mtx: sync.Mutex{}, maxSizeBytes: maxSizeBytes, allowPartialTrace: allowPartialTrace, } @@ -66,6 +69,9 @@ func (c *Combiner) Consume(tr *tempopb.Trace) (int, error) { // ConsumeWithFinal consumes the trace, but allows for performance savings when // it is known that this is the last expected input trace. func (c *Combiner) ConsumeWithFinal(tr *tempopb.Trace, final bool) (int, error) { + c.mtx.Lock() + defer c.mtx.Unlock() + var spanCount int if tr == nil || c.IsPartialTrace() { return spanCount, nil diff --git a/pkg/model/trace/combine_test.go b/pkg/model/trace/combine_test.go index bb23ee1f6a5..dda45a6b0c9 100644 --- a/pkg/model/trace/combine_test.go +++ b/pkg/model/trace/combine_test.go @@ -6,6 +6,7 @@ import ( "fmt" "sort" "strconv" + "sync" "testing" "github.com/grafana/tempo/pkg/tempopb" @@ -104,6 +105,23 @@ func TestCombinerReturnsAPartialTrace(t *testing.T) { } } +func TestCombinerParallel(t *testing.T) { + // Ensure that the combiner is safe for parallel use. + c := NewCombiner(0, false) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + _, err := c.Consume(test.MakeTraceWithSpanCount(1, 1, []byte{0x01})) + require.NoError(t, err) + } + }() + } + wg.Wait() +} + func TestTokenForIDCollision(t *testing.T) { // Estimate the hash collision rate of tokenForID.