diff --git a/_testdata/positive/security.json b/_testdata/positive/security.json index 1a032ab05..07d098a41 100644 --- a/_testdata/positive/security.json +++ b/_testdata/positive/security.json @@ -60,6 +60,21 @@ } } } + }, + "/customSecurity": { + "get": { + "operationId": "customSecurity", + "security": [ + { + "custom": [] + } + ], + "responses": { + "200": { + "description": "OK" + } + } + } } }, "components": { @@ -86,7 +101,12 @@ "bearerToken": { "type": "http", "scheme": "bearer" + }, + "custom": { + "type": "http", + "scheme": "digest", + "x-ogen-custom-security": true } } } -} +} \ No newline at end of file diff --git a/internal/integration/security_test.go b/internal/integration/security_test.go index 346bfcf04..0c9a8db18 100644 --- a/internal/integration/security_test.go +++ b/internal/integration/security_test.go @@ -14,12 +14,15 @@ import ( "github.com/ogen-go/ogen/ogenerrors" ) +const customSecurityHeader = "X-Foo-Custom" + type testSecurity struct { basicAuth api.BasicAuth bearerToken api.BearerToken headerKey api.HeaderKey queryKey api.QueryKey cookieKey api.CookieKey + custom string } func (t *testSecurity) OptionalSecurity(ctx context.Context) error { @@ -34,6 +37,10 @@ func (t *testSecurity) IntersectSecurity(ctx context.Context) error { return nil } +func (t *testSecurity) CustomSecurity(ctx context.Context) error { + return nil +} + type tokenKey string func (t *testSecurity) HandleBasicAuth(ctx context.Context, operationName string, v api.BasicAuth) (context.Context, error) { @@ -71,12 +78,21 @@ func (t *testSecurity) HandleCookieKey(ctx context.Context, operationName string return context.WithValue(ctx, tokenKey("CookieKey"), v), nil } +func (t *testSecurity) HandleCustom(ctx context.Context, operationName string, req *http.Request) (context.Context, error) { + got := req.Header.Get(customSecurityHeader) + if got != t.custom { + return nil, errors.Errorf("invalid custom auth: %q", got) + } + return context.WithValue(ctx, tokenKey("Custom"), got), nil +} + type testSecuritySource struct { basicAuth *api.BasicAuth bearerToken *api.BearerToken headerKey *api.HeaderKey queryKey *api.QueryKey cookieKey *api.CookieKey + custom string } func (t *testSecuritySource) BasicAuth(ctx context.Context, operationName string) (r api.BasicAuth, _ error) { @@ -114,6 +130,14 @@ func (t *testSecuritySource) CookieKey(ctx context.Context, operationName string return r, ogenerrors.ErrSkipClientSecurity } +func (t *testSecuritySource) Custom(ctx context.Context, operationName string, req *http.Request) error { + if t.custom == "" { + return ogenerrors.ErrSkipClientSecurity + } + req.Header.Set(customSecurityHeader, t.custom) + return nil +} + func TestSecurity(t *testing.T) { h := &testSecurity{ basicAuth: api.BasicAuth{Username: "username", Password: "password"}, @@ -121,6 +145,7 @@ func TestSecurity(t *testing.T) { headerKey: api.HeaderKey{APIKey: "HeaderKey"}, queryKey: api.QueryKey{APIKey: "QueryKey"}, cookieKey: api.CookieKey{APIKey: "CookieKey"}, + custom: "foobar-custom-token", } srv, err := api.NewServer(h, h) require.NoError(t, err) @@ -135,6 +160,7 @@ func TestSecurity(t *testing.T) { bearerToken: &h.bearerToken, headerKey: &h.headerKey, queryKey: &h.queryKey, + custom: h.custom, }, api.WithClient(s.Client())) require.NoError(t, err) @@ -217,6 +243,22 @@ func TestSecurity(t *testing.T) { require.NoError(t, client.IntersectSecurity(context.Background())) }) + t.Run("CustomSecurity", func(t *testing.T) { + resp := sendReq(t, "/customSecurity", nil) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + resp = sendReq(t, "/customSecurity", func(r *http.Request) { + r.Header.Set(customSecurityHeader, "wrong-token") + }) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + resp = sendReq(t, "/customSecurity", func(r *http.Request) { + r.Header.Set(customSecurityHeader, h.custom) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + + require.NoError(t, client.CustomSecurity(context.Background())) + }) } func TestSecurityClientCheck(t *testing.T) {