From 0e9ca4b2790c58e1227a56df797052dd25934fb1 Mon Sep 17 00:00:00 2001 From: aereal Date: Wed, 13 Dec 2023 21:14:52 +0900 Subject: [PATCH 1/2] test: load schema file at most first --- middleware_test.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/middleware_test.go b/middleware_test.go index 3e70fcc..d3e1e55 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -22,22 +22,26 @@ import ( "github.com/getkin/kin-openapi/routers/gorillamux" ) -type user struct { - Name string `json:"name"` - ID string `json:"id"` - Age int `json:"age"` -} +var router routers.Router -func TestWithValidation(t *testing.T) { +func init() { doc, err := openapi3.NewLoader().LoadFromFile("./testdata/user-account-service.openapi.json") if err != nil { - t.Fatal(err) + panic(err) } - router, err := gorillamux.NewRouter(doc) + router, err = gorillamux.NewRouter(doc) if err != nil { - t.Fatal(err) + panic(err) } +} + +type user struct { + Name string `json:"name"` + ID string `json:"id"` + Age int `json:"age"` +} +func TestWithValidation(t *testing.T) { testCases := []struct { name string handler http.Handler From 721bc0e1b5bc5b7045a7101a902e69a8633069af Mon Sep 17 00:00:00 2001 From: aereal Date: Wed, 13 Dec 2023 21:21:44 +0900 Subject: [PATCH 2/2] feat: support OpenTelemetry tracing --- go.mod | 7 +++- go.sum | 20 ++++++++++- middleware.go | 40 +++++++++++++++++++--- middleware_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index f4c150b..0b66119 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,9 @@ module github.com/aereal/go-openapi3-validation-middleware go 1.16 -require github.com/getkin/kin-openapi v0.122.0 +require ( + github.com/getkin/kin-openapi v0.122.0 + go.opentelemetry.io/otel v1.21.0 + go.opentelemetry.io/otel/sdk v1.21.0 + go.opentelemetry.io/otel/trace v1.21.0 +) diff --git a/go.sum b/go.sum index c064a4e..247c738 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/getkin/kin-openapi v0.122.0 h1:WB9Jbl0Hp/T79/JF9xlSW5Kl9uYdk/AWD0yAd9HOM10= github.com/getkin/kin-openapi v0.122.0/go.mod h1:PCWw/lfBrJY4HcdqE3jj+QFkaFK8ABoqo7PvqVhXXqw= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= +github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= @@ -11,6 +16,8 @@ github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogB github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= @@ -37,12 +44,23 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc= +go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo= +go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4= +go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= +go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= +go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= +go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc= +go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/middleware.go b/middleware.go index ef2c9c9..c460b5d 100644 --- a/middleware.go +++ b/middleware.go @@ -1,6 +1,7 @@ package openapi3middleware import ( + "context" "encoding/json" "errors" "fmt" @@ -9,6 +10,8 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" "github.com/getkin/kin-openapi/routers" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" ) type middleware = func(next http.Handler) http.Handler @@ -19,6 +22,7 @@ type MiddlewareOptions struct { ReportFindRouteError func(w http.ResponseWriter, r *http.Request, err error) ReportRequestValidationError func(w http.ResponseWriter, r *http.Request, err error) ReportResponseValidationError func(w http.ResponseWriter, r *http.Request, err error) + TracerProvider trace.TracerProvider } func (o MiddlewareOptions) reportFindRouteError(w http.ResponseWriter, r *http.Request, err error) { @@ -60,13 +64,18 @@ func WithResponseValidation(options MiddlewareOptions) middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ctx, span := getTracer(ctx, options).Start(ctx, "ResponseValidation") + defer span.End() irw := newBufferingResponseWriter(w) - next.ServeHTTP(irw, r) + next.ServeHTTP(irw, r.WithContext(ctx)) ri, err := buildRequestValidationInputFromRequest(options.Router, r, options.ValidationOptions) if frErr := new(findRouteErr); errors.As(err, &frErr) { - options.reportFindRouteError(w, r, frErr.Unwrap()) + actualErr := frErr.Unwrap() + span.RecordError(actualErr) + options.reportFindRouteError(w, r, actualErr) return } else if err != nil { + span.RecordError(err) respondErrorJSON(w, http.StatusInternalServerError, err) return } @@ -81,6 +90,7 @@ func WithResponseValidation(options MiddlewareOptions) middleware { bodyBytes := irw.buf.Bytes() input.SetBodyBytes(bodyBytes) if err := openapi3filter.ValidateResponse(ctx, input); err != nil { + span.RecordError(err) options.reportRespError(w, r, err) return } @@ -94,20 +104,26 @@ func WithResponseValidation(options MiddlewareOptions) middleware { func WithRequestValidation(options MiddlewareOptions) middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx, span := getTracer(ctx, options).Start(ctx, "RequestValidation") + defer span.End() input, err := buildRequestValidationInputFromRequest(options.Router, r, options.ValidationOptions) if frErr := new(findRouteErr); errors.As(err, &frErr) { - options.reportFindRouteError(w, r, frErr.Unwrap()) + actualErr := frErr.Unwrap() + span.RecordError(actualErr) + options.reportFindRouteError(w, r, actualErr) return } else if err != nil { + span.RecordError(err) respondErrorJSON(w, http.StatusInternalServerError, err) return } - ctx := r.Context() if err := openapi3filter.ValidateRequest(ctx, input); err != nil { + span.RecordError(err) options.reportReqError(w, r, err) return } - next.ServeHTTP(w, r) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } @@ -218,3 +234,17 @@ func respondJSON(w http.ResponseWriter, statusCode int, payload interface{}) err w.WriteHeader(statusCode) return json.NewEncoder(w).Encode(payload) } + +const tracerName = "github.com/aereal/go-openapi3-validation-middleware" + +func getTracer(ctx context.Context, opts MiddlewareOptions) trace.Tracer { + tp := opts.TracerProvider + if tp == nil { + if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() { + tp = span.TracerProvider() + } else { + tp = otel.GetTracerProvider() + } + } + return tp.Tracer(tracerName) +} diff --git a/middleware_test.go b/middleware_test.go index d3e1e55..a0feb3a 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -3,6 +3,7 @@ package openapi3middleware import ( "bufio" "bytes" + "context" "encoding/json" "errors" "fmt" @@ -20,6 +21,11 @@ import ( "github.com/getkin/kin-openapi/openapi3filter" "github.com/getkin/kin-openapi/routers" "github.com/getkin/kin-openapi/routers/gorillamux" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" ) var router routers.Router @@ -182,6 +188,82 @@ func TestWithValidation(t *testing.T) { } } +func TestWithValidation_otel(t *testing.T) { + testCases := []struct { + name string + buildOptions func(tp trace.TracerProvider) MiddlewareOptions + wantSpans int + }{ + { + name: "ok/explicitly passing TracerProvider", + buildOptions: func(tp trace.TracerProvider) MiddlewareOptions { + return MiddlewareOptions{ + Router: router, + TracerProvider: tp, + } + }, + wantSpans: 3, + }, + { + name: "ok/use TracerProvider comes from the current span", + buildOptions: func(_ trace.TracerProvider) MiddlewareOptions { + return MiddlewareOptions{Router: router} + }, + wantSpans: 3, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + if deadline, ok := t.Deadline(); ok { + ctx, cancel = context.WithDeadline(ctx, deadline) + } + defer cancel() + + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithBatcher(exporter)) + + withOtel := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + ctx, span := tp.Tracer("test").Start(ctx, fmt.Sprintf("%s %s", r.Method, r.URL.Path)) + defer span.End() + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + _ = json.NewEncoder(w).Encode(user{Name: "aereal", Age: 17, ID: "123"}) + }) + srv := httptest.NewServer(withOtel(WithValidation(tc.buildOptions(tp))(handler))) + defer srv.Close() + + req := mustRequest(newRequest(http.MethodPost, srv.URL+"/users", map[string]string{"content-type": "application/json"}, `{"name":"aereal","age":17}`)) + resp, err := srv.Client().Do(req.WithContext(ctx)) + if err != nil { + t.Fatalf("http.Client.Do: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("unexpected status code: %d", resp.StatusCode) + } + + if err := tp.ForceFlush(ctx); err != nil { + t.Fatal(err) + } + spans := exporter.GetSpans() + t.Logf("%d spans got", len(spans)) + for i, span := range spans { + t.Logf("#%d: %#v", i, span) + } + if len(spans) != tc.wantSpans { + t.Errorf("spans count:\nwant: %d\ngot: %d", tc.wantSpans, len(spans)) + } + }) + } +} + func resumeResponse(testName string, got *http.Response) (*http.Response, error) { imported, err := importResponse(testName) if err == nil {