Skip to content

Commit

Permalink
feat: support OpenTelemetry tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
aereal committed Dec 13, 2023
1 parent 0e9ca4b commit 721bc0e
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 7 deletions.
7 changes: 6 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@ module github.com/aereal/go-openapi3-validation-middleware

go 1.16

require github.com/getkin/kin-openapi v0.122.0
require (
github.com/getkin/kin-openapi v0.122.0
go.opentelemetry.io/otel v1.21.0
go.opentelemetry.io/otel/sdk v1.21.0
go.opentelemetry.io/otel/trace v1.21.0
)
20 changes: 19 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/getkin/kin-openapi v0.122.0 h1:WB9Jbl0Hp/T79/JF9xlSW5Kl9uYdk/AWD0yAd9HOM10=
github.com/getkin/kin-openapi v0.122.0/go.mod h1:PCWw/lfBrJY4HcdqE3jj+QFkaFK8ABoqo7PvqVhXXqw=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE=
github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs=
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
Expand All @@ -37,12 +44,23 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc=
go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo=
go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4=
go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM=
go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8=
go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E=
go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc=
go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
Expand Down
40 changes: 35 additions & 5 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openapi3middleware

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -9,6 +10,8 @@ import (
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/getkin/kin-openapi/routers"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
)

type middleware = func(next http.Handler) http.Handler
Expand All @@ -19,6 +22,7 @@ type MiddlewareOptions struct {
ReportFindRouteError func(w http.ResponseWriter, r *http.Request, err error)
ReportRequestValidationError func(w http.ResponseWriter, r *http.Request, err error)
ReportResponseValidationError func(w http.ResponseWriter, r *http.Request, err error)
TracerProvider trace.TracerProvider
}

func (o MiddlewareOptions) reportFindRouteError(w http.ResponseWriter, r *http.Request, err error) {
Expand Down Expand Up @@ -60,13 +64,18 @@ func WithResponseValidation(options MiddlewareOptions) middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, span := getTracer(ctx, options).Start(ctx, "ResponseValidation")
defer span.End()
irw := newBufferingResponseWriter(w)
next.ServeHTTP(irw, r)
next.ServeHTTP(irw, r.WithContext(ctx))
ri, err := buildRequestValidationInputFromRequest(options.Router, r, options.ValidationOptions)
if frErr := new(findRouteErr); errors.As(err, &frErr) {
options.reportFindRouteError(w, r, frErr.Unwrap())
actualErr := frErr.Unwrap()
span.RecordError(actualErr)
options.reportFindRouteError(w, r, actualErr)
return
} else if err != nil {
span.RecordError(err)
respondErrorJSON(w, http.StatusInternalServerError, err)
return
}
Expand All @@ -81,6 +90,7 @@ func WithResponseValidation(options MiddlewareOptions) middleware {
bodyBytes := irw.buf.Bytes()
input.SetBodyBytes(bodyBytes)
if err := openapi3filter.ValidateResponse(ctx, input); err != nil {
span.RecordError(err)
options.reportRespError(w, r, err)
return
}
Expand All @@ -94,20 +104,26 @@ func WithResponseValidation(options MiddlewareOptions) middleware {
func WithRequestValidation(options MiddlewareOptions) middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, span := getTracer(ctx, options).Start(ctx, "RequestValidation")
defer span.End()
input, err := buildRequestValidationInputFromRequest(options.Router, r, options.ValidationOptions)
if frErr := new(findRouteErr); errors.As(err, &frErr) {
options.reportFindRouteError(w, r, frErr.Unwrap())
actualErr := frErr.Unwrap()
span.RecordError(actualErr)
options.reportFindRouteError(w, r, actualErr)
return
} else if err != nil {
span.RecordError(err)
respondErrorJSON(w, http.StatusInternalServerError, err)
return
}
ctx := r.Context()
if err := openapi3filter.ValidateRequest(ctx, input); err != nil {
span.RecordError(err)
options.reportReqError(w, r, err)
return
}
next.ServeHTTP(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
Expand Down Expand Up @@ -218,3 +234,17 @@ func respondJSON(w http.ResponseWriter, statusCode int, payload interface{}) err
w.WriteHeader(statusCode)
return json.NewEncoder(w).Encode(payload)
}

const tracerName = "github.com/aereal/go-openapi3-validation-middleware"

func getTracer(ctx context.Context, opts MiddlewareOptions) trace.Tracer {
tp := opts.TracerProvider
if tp == nil {
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
tp = span.TracerProvider()
} else {
tp = otel.GetTracerProvider()
}
}
return tp.Tracer(tracerName)
}
82 changes: 82 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openapi3middleware
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -20,6 +21,11 @@ import (
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/getkin/kin-openapi/routers"
"github.com/getkin/kin-openapi/routers/gorillamux"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"go.opentelemetry.io/otel/trace"
)

var router routers.Router
Expand Down Expand Up @@ -182,6 +188,82 @@ func TestWithValidation(t *testing.T) {
}
}

func TestWithValidation_otel(t *testing.T) {
testCases := []struct {
name string
buildOptions func(tp trace.TracerProvider) MiddlewareOptions
wantSpans int
}{
{
name: "ok/explicitly passing TracerProvider",
buildOptions: func(tp trace.TracerProvider) MiddlewareOptions {
return MiddlewareOptions{
Router: router,
TracerProvider: tp,
}
},
wantSpans: 3,
},
{
name: "ok/use TracerProvider comes from the current span",
buildOptions: func(_ trace.TracerProvider) MiddlewareOptions {
return MiddlewareOptions{Router: router}
},
wantSpans: 3,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
if deadline, ok := t.Deadline(); ok {
ctx, cancel = context.WithDeadline(ctx, deadline)
}
defer cancel()

exporter := tracetest.NewInMemoryExporter()
tp := sdktrace.NewTracerProvider(sdktrace.WithBatcher(exporter))

withOtel := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
ctx, span := tp.Tracer("test").Start(ctx, fmt.Sprintf("%s %s", r.Method, r.URL.Path))
defer span.End()
next.ServeHTTP(w, r.WithContext(ctx))
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
_ = json.NewEncoder(w).Encode(user{Name: "aereal", Age: 17, ID: "123"})
})
srv := httptest.NewServer(withOtel(WithValidation(tc.buildOptions(tp))(handler)))
defer srv.Close()

req := mustRequest(newRequest(http.MethodPost, srv.URL+"/users", map[string]string{"content-type": "application/json"}, `{"name":"aereal","age":17}`))
resp, err := srv.Client().Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("http.Client.Do: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("unexpected status code: %d", resp.StatusCode)
}

if err := tp.ForceFlush(ctx); err != nil {
t.Fatal(err)
}
spans := exporter.GetSpans()
t.Logf("%d spans got", len(spans))
for i, span := range spans {
t.Logf("#%d: %#v", i, span)
}
if len(spans) != tc.wantSpans {
t.Errorf("spans count:\nwant: %d\ngot: %d", tc.wantSpans, len(spans))
}
})
}
}

func resumeResponse(testName string, got *http.Response) (*http.Response, error) {
imported, err := importResponse(testName)
if err == nil {
Expand Down

0 comments on commit 721bc0e

Please sign in to comment.