Skip to content

Commit

Permalink
Merge pull request #1058 from tdakkota/feat/custom-security
Browse files Browse the repository at this point in the history
feat(parser): add `x-ogen-custom-security` extension
  • Loading branch information
ernado authored Oct 2, 2023
2 parents 76f91a2 + 4135e32 commit 28d1402
Show file tree
Hide file tree
Showing 17 changed files with 481 additions and 39 deletions.
22 changes: 21 additions & 1 deletion _testdata/positive/security.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@
}
}
}
},
"/customSecurity": {
"get": {
"operationId": "customSecurity",
"security": [
{
"custom": []
}
],
"responses": {
"200": {
"description": "OK"
}
}
}
}
},
"components": {
Expand All @@ -86,7 +101,12 @@
"bearerToken": {
"type": "http",
"scheme": "bearer"
},
"custom": {
"type": "http",
"scheme": "digest",
"x-ogen-custom-security": true
}
}
}
}
}
71 changes: 45 additions & 26 deletions gen/_template/security.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ type SecurityHandler interface {
{{- range $name, $s := $.Securities }}
// Handle{{ $s.Type.Name }} handles {{ $name }} security.
{{- template "godoc" $s.GoDoc }}
{{- if $s.Format.IsCustomSecurity }}
Handle{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) (context.Context, error)
{{- else }}
Handle{{ $s.Type.Name }}(ctx context.Context, operationName string, t {{ $s.Type.Name }}) (context.Context, error)
{{- end }}
{{- end }}
}

Expand Down Expand Up @@ -42,8 +46,8 @@ var oauth2Scopes = map[string][]string {

{{- end }}
func (s *Server) security{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) (context.Context, bool, error) {
var t {{ $s.Type.Name }}
{{- if $s.Format.IsAPIKeySecurity }}
var t {{ $s.Type.Name }}
const parameterName = {{ quote $s.ParameterName }}
{{- if $s.Kind.IsHeader }}
value := req.Header.Get(parameterName)
Expand Down Expand Up @@ -71,6 +75,7 @@ func (s *Server) security{{ $s.Type.Name }}(ctx context.Context, operationName s
{{- end }}
t.APIKey = value
{{- else if $s.Format.IsBasicHTTPSecurity }}
var t {{ $s.Type.Name }}
if _, ok := findAuthorization(req.Header, "Basic"); !ok {
return ctx, false, nil
}
Expand All @@ -81,18 +86,22 @@ func (s *Server) security{{ $s.Type.Name }}(ctx context.Context, operationName s
t.Username = username
t.Password = password
{{- else if $s.Format.IsBearerSecurity }}
var t {{ $s.Type.Name }}
token, ok := findAuthorization(req.Header, "Bearer")
if !ok {
return ctx, false, nil
}
t.Token = token
{{- else if $s.Format.IsOAuth2Security }}
var t {{ $s.Type.Name }}
token, ok := findAuthorization(req.Header, "Bearer")
if !ok {
return ctx, false, nil
}
t.Token = token
t.Scopes = oauth2Scopes[operationName]
{{- else if $s.Format.IsCustomSecurity }}
t := req
{{- else }}
{{ errorf "unexpected security %q:%q" $s.Kind $s.Format }}
{{- end }}
Expand All @@ -113,39 +122,49 @@ type SecuritySource interface {
{{- range $name, $s := $.Securities }}
// {{ $s.Type.Name }} provides {{ $name }} security value.
{{- template "godoc" $s.GoDoc }}
{{- if $s.Format.IsCustomSecurity }}
{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) error
{{- else }}
{{ $s.Type.Name }}(ctx context.Context, operationName string) ({{ $s.Type.Name }}, error)
{{- end }}
{{- end }}
}

{{- range $s := $.Securities }}
func (s *Client) security{{ $s.Type.Name }}(ctx context.Context, operationName string, req *http.Request) error {
t, err := s.sec.{{ $s.Type.Name }}(ctx, operationName)
if err != nil {
return errors.Wrap(err, {{ printf "security source %q" $s.Type.Name | quote }})
}
{{- if $s.Format.IsAPIKeySecurity }}
{{- if $s.Kind.IsHeader }}
req.Header.Set({{ quote $s.ParameterName }}, t.APIKey)
{{- else if $s.Kind.IsQuery }}
q := req.URL.Query()
q.Set({{ quote $s.ParameterName }}, t.APIKey)
req.URL.RawQuery = q.Encode()
{{- else if $s.Kind.IsCookie }}
req.AddCookie(&http.Cookie{
Name: {{ quote $s.ParameterName }},
Value: t.APIKey,
})
{{- if $s.Format.IsCustomSecurity }}
if err := s.sec.{{ $s.Type.Name }}(ctx, operationName, req); err != nil {
return errors.Wrap(err, {{ printf "security source %q" $s.Type.Name | quote }})
}
{{- else }}
t, err := s.sec.{{ $s.Type.Name }}(ctx, operationName)
if err != nil {
return errors.Wrap(err, {{ printf "security source %q" $s.Type.Name | quote }})
}
{{- if $s.Format.IsAPIKeySecurity }}
{{- if $s.Kind.IsHeader }}
req.Header.Set({{ quote $s.ParameterName }}, t.APIKey)
{{- else if $s.Kind.IsQuery }}
q := req.URL.Query()
q.Set({{ quote $s.ParameterName }}, t.APIKey)
req.URL.RawQuery = q.Encode()
{{- else if $s.Kind.IsCookie }}
req.AddCookie(&http.Cookie{
Name: {{ quote $s.ParameterName }},
Value: t.APIKey,
})
{{- else }}
{{ errorf "unexpected security %q:%q" $s.Kind $s.Format }}
{{- end }}
{{- else if $s.Format.IsBasicHTTPSecurity }}
req.SetBasicAuth(t.Username, t.Password)
{{- else if $s.Format.IsBearerSecurity }}
req.Header.Set("Authorization", "Bearer " + t.Token)
{{- else if $s.Format.IsOAuth2Security }}
req.Header.Set("Authorization", "Bearer " + t.Token)
{{- else }}
{{ errorf "unexpected security %q:%q" $s.Kind $s.Format }}
{{ errorf "unexpected security %q:%q" $s.Kind $s.Format }}
{{- end }}
{{- else if $s.Format.IsBasicHTTPSecurity }}
req.SetBasicAuth(t.Username, t.Password)
{{- else if $s.Format.IsBearerSecurity }}
req.Header.Set("Authorization", "Bearer " + t.Token)
{{- else if $s.Format.IsOAuth2Security }}
req.Header.Set("Authorization", "Bearer " + t.Token)
{{- else }}
{{ errorf "unexpected security %q:%q" $s.Kind $s.Format }}
{{- end }}
return nil
}
Expand Down
7 changes: 7 additions & 0 deletions gen/gen_security.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ func (g *Generator) generateSecurity(ctx *genctx, operationName string, spec ope
Type: t,
Description: security.Description,
}

// Do not create a type for custom security.
if security.XOgenCustomSecurity {
s.Format = ir.CustomSecurityFormat
return s, nil
}

defer func() {
if rErr == nil {
if err := ctx.saveType(t); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions gen/ir/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ const (

// Oauth2SecurityFormat is Oauth2 security format.
Oauth2SecurityFormat SecurityFormat = "oauth2"

// CustomSecurityFormat is a user-defined security format.
CustomSecurityFormat = "x-ogen-custom-security"
)

// IsAPIKeySecurity whether s is APIKeySecurityFormat.
Expand Down Expand Up @@ -84,6 +87,11 @@ func (s SecurityFormat) IsOAuth2Security() bool {
return s == Oauth2SecurityFormat
}

// IsCustomSecurity whether s is CustomSecurityFormat.
func (s SecurityFormat) IsCustomSecurity() bool {
return s == CustomSecurityFormat
}

type Security struct {
Kind SecurityKind
Format SecurityFormat
Expand Down
42 changes: 42 additions & 0 deletions internal/integration/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ import (
"github.com/ogen-go/ogen/ogenerrors"
)

const customSecurityHeader = "X-Foo-Custom"

type testSecurity struct {
basicAuth api.BasicAuth
bearerToken api.BearerToken
headerKey api.HeaderKey
queryKey api.QueryKey
cookieKey api.CookieKey
custom string
}

func (t *testSecurity) OptionalSecurity(ctx context.Context) error {
Expand All @@ -34,6 +37,10 @@ func (t *testSecurity) IntersectSecurity(ctx context.Context) error {
return nil
}

func (t *testSecurity) CustomSecurity(ctx context.Context) error {
return nil
}

type tokenKey string

func (t *testSecurity) HandleBasicAuth(ctx context.Context, operationName string, v api.BasicAuth) (context.Context, error) {
Expand Down Expand Up @@ -71,12 +78,21 @@ func (t *testSecurity) HandleCookieKey(ctx context.Context, operationName string
return context.WithValue(ctx, tokenKey("CookieKey"), v), nil
}

func (t *testSecurity) HandleCustom(ctx context.Context, operationName string, req *http.Request) (context.Context, error) {
got := req.Header.Get(customSecurityHeader)
if got != t.custom {
return nil, errors.Errorf("invalid custom auth: %q", got)
}
return context.WithValue(ctx, tokenKey("Custom"), got), nil
}

type testSecuritySource struct {
basicAuth *api.BasicAuth
bearerToken *api.BearerToken
headerKey *api.HeaderKey
queryKey *api.QueryKey
cookieKey *api.CookieKey
custom string
}

func (t *testSecuritySource) BasicAuth(ctx context.Context, operationName string) (r api.BasicAuth, _ error) {
Expand Down Expand Up @@ -114,13 +130,22 @@ func (t *testSecuritySource) CookieKey(ctx context.Context, operationName string
return r, ogenerrors.ErrSkipClientSecurity
}

func (t *testSecuritySource) Custom(ctx context.Context, operationName string, req *http.Request) error {
if t.custom == "" {
return ogenerrors.ErrSkipClientSecurity
}
req.Header.Set(customSecurityHeader, t.custom)
return nil
}

func TestSecurity(t *testing.T) {
h := &testSecurity{
basicAuth: api.BasicAuth{Username: "username", Password: "password"},
bearerToken: api.BearerToken{Token: "BearerToken"},
headerKey: api.HeaderKey{APIKey: "HeaderKey"},
queryKey: api.QueryKey{APIKey: "QueryKey"},
cookieKey: api.CookieKey{APIKey: "CookieKey"},
custom: "foobar-custom-token",
}
srv, err := api.NewServer(h, h)
require.NoError(t, err)
Expand All @@ -135,6 +160,7 @@ func TestSecurity(t *testing.T) {
bearerToken: &h.bearerToken,
headerKey: &h.headerKey,
queryKey: &h.queryKey,
custom: h.custom,
}, api.WithClient(s.Client()))
require.NoError(t, err)

Expand Down Expand Up @@ -217,6 +243,22 @@ func TestSecurity(t *testing.T) {

require.NoError(t, client.IntersectSecurity(context.Background()))
})
t.Run("CustomSecurity", func(t *testing.T) {
resp := sendReq(t, "/customSecurity", nil)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)

resp = sendReq(t, "/customSecurity", func(r *http.Request) {
r.Header.Set(customSecurityHeader, "wrong-token")
})
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)

resp = sendReq(t, "/customSecurity", func(r *http.Request) {
r.Header.Set(customSecurityHeader, h.custom)
})
require.Equal(t, http.StatusOK, resp.StatusCode)

require.NoError(t, client.CustomSecurity(context.Background()))
})
}

func TestSecurityClientCheck(t *testing.T) {
Expand Down
Loading

0 comments on commit 28d1402

Please sign in to comment.