From d0a65d165fb8a2a1263f3236c98a59ade232bcd4 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Mon, 25 Sep 2023 07:40:21 +0300 Subject: [PATCH 1/5] feat(parser): add `x-ogen-custom-security` extension --- openapi/parser/parse_security.go | 37 ++++++++++++++++++++++---------- openapi/security.go | 2 ++ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/openapi/parser/parse_security.go b/openapi/parser/parse_security.go index ed2ae12e9..dc31239d2 100644 --- a/openapi/parser/parse_security.go +++ b/openapi/parser/parse_security.go @@ -198,18 +198,33 @@ func (p *parser) parseSecurityRequirementScheme(name string, scheme *ogen.Securi if f := spec.Flows; f != nil { flows = *f } - security := openapi.Security{ - Type: spec.Type, - Description: spec.Description, - Name: spec.Name, - In: spec.In, - Scheme: spec.Scheme, - BearerFormat: spec.BearerFormat, - Flows: cloneOAuthFlows(flows, p.file(ctx)), - OpenIDConnectURL: spec.OpenIDConnectURL, - Pointer: spec.Common.Locator.Pointer(p.file(ctx)), + + var ( + custom bool + locator = spec.Common.Locator + ) + { + const extensionName = "x-ogen-custom-security" + if ex, ok := scheme.Common.Extensions[extensionName]; ok { + if err := ex.Decode(&custom); err != nil { + err := errors.Wrap(err, "unmarshal value") + return openapi.Security{}, p.wrapField(extensionName, p.file(ctx), locator, err) + } + } } - return security, nil + + return openapi.Security{ + Type: spec.Type, + Description: spec.Description, + Name: spec.Name, + In: spec.In, + Scheme: spec.Scheme, + BearerFormat: spec.BearerFormat, + Flows: cloneOAuthFlows(flows, p.file(ctx)), + OpenIDConnectURL: spec.OpenIDConnectURL, + XOgenCustomSecurity: custom, + Pointer: locator.Pointer(p.file(ctx)), + }, nil } func (p *parser) parseSecurityRequirements( diff --git a/openapi/security.go b/openapi/security.go index 1ae8abbc7..e90ee873c 100644 --- a/openapi/security.go +++ b/openapi/security.go @@ -33,6 +33,8 @@ type Security struct { Flows OAuthFlows OpenIDConnectURL string + XOgenCustomSecurity bool + location.Pointer `json:"-" yaml:"-"` } From d84da5c7cc2ce25349a751c209959e1bd4607f37 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Mon, 25 Sep 2023 07:40:44 +0300 Subject: [PATCH 2/5] feat(gen): generate custom security handler --- gen/_template/security.tmpl | 71 +++++++++++++++++++++++-------------- gen/gen_security.go | 7 ++++ gen/ir/security.go | 8 +++++ 3 files changed, 60 insertions(+), 26 deletions(-) diff --git a/gen/_template/security.tmpl b/gen/_template/security.tmpl index d6dfedd41..c3a47c7bb 100644 --- a/gen/_template/security.tmpl +++ b/gen/_template/security.tmpl @@ -9,7 +9,11 @@ type SecurityHandler interface { {{- range $name, $s := $.Securities }} // Handle{{ $s.Type.Name }} handles {{ $name }} security. {{- template "godoc" $s.GoDoc }} + {{- if $s.Format.IsCustomSecurity }} + Handle{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) (context.Context, error) + {{- else }} Handle{{ $s.Type.Name }}(ctx context.Context, operationName string, t {{ $s.Type.Name }}) (context.Context, error) + {{- end }} {{- end }} } @@ -42,8 +46,8 @@ var oauth2Scopes = map[string][]string { {{- end }} func (s *Server) security{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) (context.Context, bool, error) { - var t {{ $s.Type.Name }} {{- if $s.Format.IsAPIKeySecurity }} + var t {{ $s.Type.Name }} const parameterName = {{ quote $s.ParameterName }} {{- if $s.Kind.IsHeader }} value := req.Header.Get(parameterName) @@ -71,6 +75,7 @@ func (s *Server) security{{ $s.Type.Name }}(ctx context.Context, operationName s {{- end }} t.APIKey = value {{- else if $s.Format.IsBasicHTTPSecurity }} + var t {{ $s.Type.Name }} if _, ok := findAuthorization(req.Header, "Basic"); !ok { return ctx, false, nil } @@ -81,18 +86,22 @@ func (s *Server) security{{ $s.Type.Name }}(ctx context.Context, operationName s t.Username = username t.Password = password {{- else if $s.Format.IsBearerSecurity }} + var t {{ $s.Type.Name }} token, ok := findAuthorization(req.Header, "Bearer") if !ok { return ctx, false, nil } t.Token = token {{- else if $s.Format.IsOAuth2Security }} + var t {{ $s.Type.Name }} token, ok := findAuthorization(req.Header, "Bearer") if !ok { return ctx, false, nil } t.Token = token t.Scopes = oauth2Scopes[operationName] + {{- else if $s.Format.IsCustomSecurity }} + t := req {{- else }} {{ errorf "unexpected security %q:%q" $s.Kind $s.Format }} {{- end }} @@ -113,39 +122,49 @@ type SecuritySource interface { {{- range $name, $s := $.Securities }} // {{ $s.Type.Name }} provides {{ $name }} security value. {{- template "godoc" $s.GoDoc }} + {{- if $s.Format.IsCustomSecurity }} + {{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) error + {{- else }} {{ $s.Type.Name }}(ctx context.Context, operationName string) ({{ $s.Type.Name }}, error) + {{- end }} {{- end }} } {{- range $s := $.Securities }} func (s *Client) security{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) error { - t, err := s.sec.{{ $s.Type.Name }}(ctx, operationName) - if err != nil { - return errors.Wrap(err, {{ printf "security source %q" $s.Type.Name | quote }}) - } - {{- if $s.Format.IsAPIKeySecurity }} - {{- if $s.Kind.IsHeader }} - req.Header.Set({{ quote $s.ParameterName }}, t.APIKey) - {{- else if $s.Kind.IsQuery }} - q := req.URL.Query() - q.Set({{ quote $s.ParameterName }}, t.APIKey) - req.URL.RawQuery = q.Encode() - {{- else if $s.Kind.IsCookie }} - req.AddCookie(&http.Cookie{ - Name: {{ quote $s.ParameterName }}, - Value: t.APIKey, - }) + {{- if $s.Format.IsCustomSecurity }} + if err := s.sec.{{ $s.Type.Name }}(ctx, operationName, req); err != nil { + return errors.Wrap(err, {{ printf "security source %q" $s.Type.Name | quote }}) + } + {{- else }} + t, err := s.sec.{{ $s.Type.Name }}(ctx, operationName) + if err != nil { + return errors.Wrap(err, {{ printf "security source %q" $s.Type.Name | quote }}) + } + {{- if $s.Format.IsAPIKeySecurity }} + {{- if $s.Kind.IsHeader }} + req.Header.Set({{ quote $s.ParameterName }}, t.APIKey) + {{- else if $s.Kind.IsQuery }} + q := req.URL.Query() + q.Set({{ quote $s.ParameterName }}, t.APIKey) + req.URL.RawQuery = q.Encode() + {{- else if $s.Kind.IsCookie }} + req.AddCookie(&http.Cookie{ + Name: {{ quote $s.ParameterName }}, + Value: t.APIKey, + }) + {{- else }} + {{ errorf "unexpected security %q:%q" $s.Kind $s.Format }} + {{- end }} + {{- else if $s.Format.IsBasicHTTPSecurity }} + req.SetBasicAuth(t.Username, t.Password) + {{- else if $s.Format.IsBearerSecurity }} + req.Header.Set("Authorization", "Bearer " + t.Token) + {{- else if $s.Format.IsOAuth2Security }} + req.Header.Set("Authorization", "Bearer " + t.Token) {{- else }} - {{ errorf "unexpected security %q:%q" $s.Kind $s.Format }} + {{ errorf "unexpected security %q:%q" $s.Kind $s.Format }} {{- end }} - {{- else if $s.Format.IsBasicHTTPSecurity }} - req.SetBasicAuth(t.Username, t.Password) - {{- else if $s.Format.IsBearerSecurity }} - req.Header.Set("Authorization", "Bearer " + t.Token) - {{- else if $s.Format.IsOAuth2Security }} - req.Header.Set("Authorization", "Bearer " + t.Token) - {{- else }} - {{ errorf "unexpected security %q:%q" $s.Kind $s.Format }} {{- end }} return nil } diff --git a/gen/gen_security.go b/gen/gen_security.go index f2f51c27e..943aef827 100644 --- a/gen/gen_security.go +++ b/gen/gen_security.go @@ -113,6 +113,13 @@ func (g *Generator) generateSecurity(ctx *genctx, operationName string, spec ope Type: t, Description: security.Description, } + + // Do not create a type for custom security. + if security.XOgenCustomSecurity { + s.Format = ir.CustomSecurityFormat + return s, nil + } + defer func() { if rErr == nil { if err := ctx.saveType(t); err != nil { diff --git a/gen/ir/security.go b/gen/ir/security.go index eee193a28..72f561c97 100644 --- a/gen/ir/security.go +++ b/gen/ir/security.go @@ -57,6 +57,9 @@ const ( // Oauth2SecurityFormat is Oauth2 security format. Oauth2SecurityFormat SecurityFormat = "oauth2" + + // CustomSecurityFormat is a user-defined security format. + CustomSecurityFormat = "x-ogen-custom-security" ) // IsAPIKeySecurity whether s is APIKeySecurityFormat. @@ -84,6 +87,11 @@ func (s SecurityFormat) IsOAuth2Security() bool { return s == Oauth2SecurityFormat } +// IsCustomSecurity whether s is CustomSecurityFormat. +func (s SecurityFormat) IsCustomSecurity() bool { + return s == CustomSecurityFormat +} + type Security struct { Kind SecurityKind Format SecurityFormat From 17e2d1e7009e5c0e0aaec6baab4976024a06bd57 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Mon, 25 Sep 2023 07:40:59 +0300 Subject: [PATCH 3/5] test(integration): add custom security test --- _testdata/positive/security.json | 22 +++++++++++++- internal/integration/security_test.go | 42 +++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/_testdata/positive/security.json b/_testdata/positive/security.json index 1a032ab05..07d098a41 100644 --- a/_testdata/positive/security.json +++ b/_testdata/positive/security.json @@ -60,6 +60,21 @@ } } } + }, + "/customSecurity": { + "get": { + "operationId": "customSecurity", + "security": [ + { + "custom": [] + } + ], + "responses": { + "200": { + "description": "OK" + } + } + } } }, "components": { @@ -86,7 +101,12 @@ "bearerToken": { "type": "http", "scheme": "bearer" + }, + "custom": { + "type": "http", + "scheme": "digest", + "x-ogen-custom-security": true } } } -} +} \ No newline at end of file diff --git a/internal/integration/security_test.go b/internal/integration/security_test.go index 346bfcf04..0c9a8db18 100644 --- a/internal/integration/security_test.go +++ b/internal/integration/security_test.go @@ -14,12 +14,15 @@ import ( "github.com/ogen-go/ogen/ogenerrors" ) +const customSecurityHeader = "X-Foo-Custom" + type testSecurity struct { basicAuth api.BasicAuth bearerToken api.BearerToken headerKey api.HeaderKey queryKey api.QueryKey cookieKey api.CookieKey + custom string } func (t *testSecurity) OptionalSecurity(ctx context.Context) error { @@ -34,6 +37,10 @@ func (t *testSecurity) IntersectSecurity(ctx context.Context) error { return nil } +func (t *testSecurity) CustomSecurity(ctx context.Context) error { + return nil +} + type tokenKey string func (t *testSecurity) HandleBasicAuth(ctx context.Context, operationName string, v api.BasicAuth) (context.Context, error) { @@ -71,12 +78,21 @@ func (t *testSecurity) HandleCookieKey(ctx context.Context, operationName string return context.WithValue(ctx, tokenKey("CookieKey"), v), nil } +func (t *testSecurity) HandleCustom(ctx context.Context, operationName string, req *http.Request) (context.Context, error) { + got := req.Header.Get(customSecurityHeader) + if got != t.custom { + return nil, errors.Errorf("invalid custom auth: %q", got) + } + return context.WithValue(ctx, tokenKey("Custom"), got), nil +} + type testSecuritySource struct { basicAuth *api.BasicAuth bearerToken *api.BearerToken headerKey *api.HeaderKey queryKey *api.QueryKey cookieKey *api.CookieKey + custom string } func (t *testSecuritySource) BasicAuth(ctx context.Context, operationName string) (r api.BasicAuth, _ error) { @@ -114,6 +130,14 @@ func (t *testSecuritySource) CookieKey(ctx context.Context, operationName string return r, ogenerrors.ErrSkipClientSecurity } +func (t *testSecuritySource) Custom(ctx context.Context, operationName string, req *http.Request) error { + if t.custom == "" { + return ogenerrors.ErrSkipClientSecurity + } + req.Header.Set(customSecurityHeader, t.custom) + return nil +} + func TestSecurity(t *testing.T) { h := &testSecurity{ basicAuth: api.BasicAuth{Username: "username", Password: "password"}, @@ -121,6 +145,7 @@ func TestSecurity(t *testing.T) { headerKey: api.HeaderKey{APIKey: "HeaderKey"}, queryKey: api.QueryKey{APIKey: "QueryKey"}, cookieKey: api.CookieKey{APIKey: "CookieKey"}, + custom: "foobar-custom-token", } srv, err := api.NewServer(h, h) require.NoError(t, err) @@ -135,6 +160,7 @@ func TestSecurity(t *testing.T) { bearerToken: &h.bearerToken, headerKey: &h.headerKey, queryKey: &h.queryKey, + custom: h.custom, }, api.WithClient(s.Client())) require.NoError(t, err) @@ -217,6 +243,22 @@ func TestSecurity(t *testing.T) { require.NoError(t, client.IntersectSecurity(context.Background())) }) + t.Run("CustomSecurity", func(t *testing.T) { + resp := sendReq(t, "/customSecurity", nil) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + resp = sendReq(t, "/customSecurity", func(r *http.Request) { + r.Header.Set(customSecurityHeader, "wrong-token") + }) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + resp = sendReq(t, "/customSecurity", func(r *http.Request) { + r.Header.Set(customSecurityHeader, h.custom) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + + require.NoError(t, client.CustomSecurity(context.Background())) + }) } func TestSecurityClientCheck(t *testing.T) { From b52b03ccac6e2075633f19653844a315e6332c67 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Mon, 25 Sep 2023 07:41:39 +0300 Subject: [PATCH 4/5] chore(validate): remove debug line --- validate/array_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/validate/array_test.go b/validate/array_test.go index e64b07118..29b41051c 100644 --- a/validate/array_test.go +++ b/validate/array_test.go @@ -115,7 +115,6 @@ func TestUniqueItems(t *testing.T) { } { tt := tt t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { - fmt.Println("Test", i) check := require.NoError if tt.WantErr { check = require.Error From 4135e32ea43cdc8955a139ea9c172d3a50ddd624 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Mon, 25 Sep 2023 07:41:47 +0300 Subject: [PATCH 5/5] chore: commit generated files --- .../test_security/oas_client_gen.go | 107 ++++++++++++++ .../test_security/oas_handlers_gen.go | 133 ++++++++++++++++++ .../oas_response_decoders_gen.go | 9 ++ .../oas_response_encoders_gen.go | 7 + .../test_security/oas_router_gen.go | 40 ++++++ .../test_security/oas_schemas_gen.go | 3 + .../test_security/oas_security_gen.go | 20 +++ .../test_security/oas_server_gen.go | 4 + .../test_security/oas_unimplemented_gen.go | 7 + 9 files changed, 330 insertions(+) diff --git a/internal/integration/test_security/oas_client_gen.go b/internal/integration/test_security/oas_client_gen.go index 6d25b9edc..561f28620 100644 --- a/internal/integration/test_security/oas_client_gen.go +++ b/internal/integration/test_security/oas_client_gen.go @@ -23,6 +23,10 @@ import ( // Invoker invokes operations described by OpenAPI v3 specification. type Invoker interface { + // CustomSecurity invokes customSecurity operation. + // + // GET /customSecurity + CustomSecurity(ctx context.Context) error // DisjointSecurity invokes disjointSecurity operation. // // GET /disjointSecurity @@ -87,6 +91,109 @@ func (c *Client) requestURL(ctx context.Context) *url.URL { return u } +// CustomSecurity invokes customSecurity operation. +// +// GET /customSecurity +func (c *Client) CustomSecurity(ctx context.Context) error { + _, err := c.sendCustomSecurity(ctx) + return err +} + +func (c *Client) sendCustomSecurity(ctx context.Context) (res *CustomSecurityOK, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("customSecurity"), + semconv.HTTPMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/customSecurity"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, "CustomSecurity", + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/customSecurity" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + { + type bitset = [1]uint8 + var satisfied bitset + { + stage = "Security:Custom" + switch err := c.securityCustom(ctx, "CustomSecurity", r); { + case err == nil: // if NO error + satisfied[0] |= 1 << 0 + case errors.Is(err, ogenerrors.ErrSkipClientSecurity): + // Skip this security. + default: + return res, errors.Wrap(err, "security \"Custom\"") + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + return res, ogenerrors.ErrSecurityRequirementIsNotSatisfied + } + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeCustomSecurityResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} + // DisjointSecurity invokes disjointSecurity operation. // // GET /disjointSecurity diff --git a/internal/integration/test_security/oas_handlers_gen.go b/internal/integration/test_security/oas_handlers_gen.go index e61cd9bab..e46bdd57d 100644 --- a/internal/integration/test_security/oas_handlers_gen.go +++ b/internal/integration/test_security/oas_handlers_gen.go @@ -20,6 +20,139 @@ import ( "github.com/ogen-go/ogen/otelogen" ) +// handleCustomSecurityRequest handles customSecurity operation. +// +// GET /customSecurity +func (s *Server) handleCustomSecurityRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("customSecurity"), + semconv.HTTPMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/customSecurity"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), "CustomSecurity", + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + s.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + s.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: "CustomSecurity", + ID: "customSecurity", + } + ) + { + type bitset = [1]uint8 + var satisfied bitset + { + sctx, ok, err := s.securityCustom(ctx, "CustomSecurity", r) + if err != nil { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Security: "Custom", + Err: err, + } + recordError("Security:Custom", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + if ok { + satisfied[0] |= 1 << 0 + ctx = sctx + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Err: ogenerrors.ErrSecurityRequirementIsNotSatisfied, + } + recordError("Security", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + } + + var response *CustomSecurityOK + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: "CustomSecurity", + OperationSummary: "", + OperationID: "customSecurity", + Body: nil, + Params: middleware.Parameters{}, + Raw: r, + } + + type ( + Request = struct{} + Params = struct{} + Response = *CustomSecurityOK + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + nil, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + err = s.h.CustomSecurity(ctx) + return response, err + }, + ) + } else { + err = s.h.CustomSecurity(ctx) + } + if err != nil { + recordError("Internal", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + if err := encodeCustomSecurityResponse(response, w, span); err != nil { + recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} + // handleDisjointSecurityRequest handles disjointSecurity operation. // // GET /disjointSecurity diff --git a/internal/integration/test_security/oas_response_decoders_gen.go b/internal/integration/test_security/oas_response_decoders_gen.go index 440bf4b0e..4da9041cb 100644 --- a/internal/integration/test_security/oas_response_decoders_gen.go +++ b/internal/integration/test_security/oas_response_decoders_gen.go @@ -8,6 +8,15 @@ import ( "github.com/ogen-go/ogen/validate" ) +func decodeCustomSecurityResponse(resp *http.Response) (res *CustomSecurityOK, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + return &CustomSecurityOK{}, nil + } + return res, validate.UnexpectedStatusCode(resp.StatusCode) +} + func decodeDisjointSecurityResponse(resp *http.Response) (res *DisjointSecurityOK, _ error) { switch resp.StatusCode { case 200: diff --git a/internal/integration/test_security/oas_response_encoders_gen.go b/internal/integration/test_security/oas_response_encoders_gen.go index 547bd716f..369e0ca78 100644 --- a/internal/integration/test_security/oas_response_encoders_gen.go +++ b/internal/integration/test_security/oas_response_encoders_gen.go @@ -9,6 +9,13 @@ import ( "go.opentelemetry.io/otel/trace" ) +func encodeCustomSecurityResponse(response *CustomSecurityOK, w http.ResponseWriter, span trace.Span) error { + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + return nil +} + func encodeDisjointSecurityResponse(response *DisjointSecurityOK, w http.ResponseWriter, span trace.Span) error { w.WriteHeader(200) span.SetStatus(codes.Ok, http.StatusText(200)) diff --git a/internal/integration/test_security/oas_router_gen.go b/internal/integration/test_security/oas_router_gen.go index 45326b85d..6049d712a 100644 --- a/internal/integration/test_security/oas_router_gen.go +++ b/internal/integration/test_security/oas_router_gen.go @@ -59,6 +59,24 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { break } switch elem[0] { + case 'c': // Prefix: "customSecurity" + if l := len("customSecurity"); len(elem) >= l && elem[0:l] == "customSecurity" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleCustomSecurityRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } case 'd': // Prefix: "disjointSecurity" if l := len("disjointSecurity"); len(elem) >= l && elem[0:l] == "disjointSecurity" { elem = elem[l:] @@ -205,6 +223,28 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { break } switch elem[0] { + case 'c': // Prefix: "customSecurity" + if l := len("customSecurity"); len(elem) >= l && elem[0:l] == "customSecurity" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + switch method { + case "GET": + // Leaf: CustomSecurity + r.name = "CustomSecurity" + r.summary = "" + r.operationID = "customSecurity" + r.pathPattern = "/customSecurity" + r.args = args + r.count = 0 + return r, true + default: + return + } + } case 'd': // Prefix: "disjointSecurity" if l := len("disjointSecurity"); len(elem) >= l && elem[0:l] == "disjointSecurity" { elem = elem[l:] diff --git a/internal/integration/test_security/oas_schemas_gen.go b/internal/integration/test_security/oas_schemas_gen.go index 808a1f528..fb9df6c00 100644 --- a/internal/integration/test_security/oas_schemas_gen.go +++ b/internal/integration/test_security/oas_schemas_gen.go @@ -55,6 +55,9 @@ func (s *CookieKey) SetAPIKey(val string) { s.APIKey = val } +// CustomSecurityOK is response for CustomSecurity operation. +type CustomSecurityOK struct{} + // DisjointSecurityOK is response for DisjointSecurity operation. type DisjointSecurityOK struct{} diff --git a/internal/integration/test_security/oas_security_gen.go b/internal/integration/test_security/oas_security_gen.go index da70a1412..f8992df99 100644 --- a/internal/integration/test_security/oas_security_gen.go +++ b/internal/integration/test_security/oas_security_gen.go @@ -20,6 +20,8 @@ type SecurityHandler interface { HandleBearerToken(ctx context.Context, operationName string, t BearerToken) (context.Context, error) // HandleCookieKey handles cookieKey security. HandleCookieKey(ctx context.Context, operationName string, t CookieKey) (context.Context, error) + // HandleCustom handles custom security. + HandleCustom(ctx context.Context, operationName string, req *http.Request) (context.Context, error) // HandleHeaderKey handles headerKey security. HandleHeaderKey(ctx context.Context, operationName string, t HeaderKey) (context.Context, error) // HandleQueryKey handles queryKey security. @@ -96,6 +98,16 @@ func (s *Server) securityCookieKey(ctx context.Context, operationName string, re } return rctx, true, err } +func (s *Server) securityCustom(ctx context.Context, operationName string, req *http.Request) (context.Context, bool, error) { + t := req + rctx, err := s.sec.HandleCustom(ctx, operationName, t) + if errors.Is(err, ogenerrors.ErrSkipServerSecurity) { + return nil, false, nil + } else if err != nil { + return nil, false, err + } + return rctx, true, err +} func (s *Server) securityHeaderKey(ctx context.Context, operationName string, req *http.Request) (context.Context, bool, error) { var t HeaderKey const parameterName = "X-Api-Key" @@ -138,6 +150,8 @@ type SecuritySource interface { BearerToken(ctx context.Context, operationName string) (BearerToken, error) // CookieKey provides cookieKey security value. CookieKey(ctx context.Context, operationName string) (CookieKey, error) + // Custom provides custom security value. + Custom(ctx context.Context, operationName string, req *http.Request) error // HeaderKey provides headerKey security value. HeaderKey(ctx context.Context, operationName string) (HeaderKey, error) // QueryKey provides queryKey security value. @@ -171,6 +185,12 @@ func (s *Client) securityCookieKey(ctx context.Context, operationName string, re }) return nil } +func (s *Client) securityCustom(ctx context.Context, operationName string, req *http.Request) error { + if err := s.sec.Custom(ctx, operationName, req); err != nil { + return errors.Wrap(err, "security source \"Custom\"") + } + return nil +} func (s *Client) securityHeaderKey(ctx context.Context, operationName string, req *http.Request) error { t, err := s.sec.HeaderKey(ctx, operationName) if err != nil { diff --git a/internal/integration/test_security/oas_server_gen.go b/internal/integration/test_security/oas_server_gen.go index dc189325a..02974740c 100644 --- a/internal/integration/test_security/oas_server_gen.go +++ b/internal/integration/test_security/oas_server_gen.go @@ -8,6 +8,10 @@ import ( // Handler handles operations described by OpenAPI v3 specification. type Handler interface { + // CustomSecurity implements customSecurity operation. + // + // GET /customSecurity + CustomSecurity(ctx context.Context) error // DisjointSecurity implements disjointSecurity operation. // // GET /disjointSecurity diff --git a/internal/integration/test_security/oas_unimplemented_gen.go b/internal/integration/test_security/oas_unimplemented_gen.go index 0336bbba9..06dc18c8b 100644 --- a/internal/integration/test_security/oas_unimplemented_gen.go +++ b/internal/integration/test_security/oas_unimplemented_gen.go @@ -13,6 +13,13 @@ type UnimplementedHandler struct{} var _ Handler = UnimplementedHandler{} +// CustomSecurity implements customSecurity operation. +// +// GET /customSecurity +func (UnimplementedHandler) CustomSecurity(ctx context.Context) error { + return ht.ErrNotImplemented +} + // DisjointSecurity implements disjointSecurity operation. // // GET /disjointSecurity