diff --git a/internal/proxy/testoutput/oapi-proxy-impl.go b/internal/proxy/testoutput/oapi-proxy-impl.go index 6050cae..4bcb9e5 100644 --- a/internal/proxy/testoutput/oapi-proxy-impl.go +++ b/internal/proxy/testoutput/oapi-proxy-impl.go @@ -55,6 +55,7 @@ func (s serverImpl) GetProfile(ctx context.Context, request GetProfileRequestObj // GetValidatedProfile implements StrictServerInterface. func (s serverImpl) GetValidatedProfile(ctx context.Context, request GetValidatedProfileRequestObject) (UpstreamProfileGetProfileRequestObject, error) { + authzExpect(ctx, func(ae *authzExpectations) []authzExpectation { return []authzExpectation{ae.False()} }) return UpstreamProfileGetProfileRequestObject{ TenantId: ctx.Value(ctxTenantID{}).(uuid.UUID), ProfileId: request.ProfileId, @@ -77,7 +78,7 @@ func (s serverImpl) PutProfile(ctx context.Context, request PutProfileRequestObj type ctxTenantID struct{} -func insertTenantIDMiddleware(tenantID uuid.UUID) strictecho.StrictEchoMiddlewareFunc { +func injectTenantIDMiddleware(tenantID uuid.UUID) strictecho.StrictEchoMiddlewareFunc { return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { return func(ctx echo.Context, request interface{}) (response interface{}, err error) { ctx.SetRequest( @@ -111,9 +112,9 @@ func selectivePasstroughMiddleware() strictecho.StrictEchoMiddlewareFunc { func authz() strictecho.StrictEchoMiddlewareFunc { return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { return func(ctx echo.Context, request interface{}) (response interface{}, err error) { - exp := authzExpectations{ - tenantID: ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID), - } + // init expectations + exp := authzExpectations{} + exp.tenantID, _ = ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID) ctx.SetRequest(ctx.Request().WithContext(exp.Attach(ctx.Request().Context()))) // exec handler @@ -122,11 +123,12 @@ func authz() strictecho.StrictEchoMiddlewareFunc { return nil, err } - // assert permission result - if exp.err != nil { + // verify expectations result + result, err := authzExpect(ctx.Request().Context(), nil) + if err != nil { return nil, err } - if !exp.result { + if !result { return nil, echo.NewHTTPError(http.StatusForbidden, "forbidden") } return res, err @@ -139,8 +141,9 @@ type authzExpectation func() (bool, error) type authzExpectations struct { tenantID uuid.UUID - result bool - err error + invoked bool + result bool + err error } func (a *authzExpectations) Attach(ctx context.Context) context.Context { @@ -190,6 +193,16 @@ func (a *authzExpectations) False() authzExpectation { return func() (bool, error) { return false, nil } } +func (a *authzExpectations) tenantIDNotZero() authzExpectation { + if a == nil { + a.False() + } + + return func() (bool, error) { + return a.tenantID != uuid.UUID{}, nil + } +} + func (a *authzExpectations) ProfileIDNotZero(profileID uuid.UUID) authzExpectation { if a == nil { a.False() @@ -200,25 +213,31 @@ func (a *authzExpectations) ProfileIDNotZero(profileID uuid.UUID) authzExpectati } } -func authzExpect(ctx context.Context, f func(*authzExpectations) []authzExpectation) (bool, error) { +func authzExpect(ctx context.Context, addExp func(*authzExpectations) []authzExpectation) (bool, error) { v, _ := (ctx.Value(authzExpectations{})).(*authzExpectations) if v == nil { return false, nil } - // required expectation - if v.tenantID == (uuid.UUID{}) { - return false, nil + if v.invoked { + return v.result, v.err } - // additional expectation - for _, req := range f(v) { - v.result, v.err = req() + // required expectations + exps := []authzExpectation{v.tenantIDNotZero()} + // additional expectations + if addExp != nil { + exps = append(exps, addExp(v)...) + } + + v.invoked = true + for _, exp := range exps { + v.result, v.err = exp() if !v.result || v.err != nil { return v.result, v.err } } - v.result = true + return v.result, nil } diff --git a/internal/proxy/testoutput/oapi-proxy_test.go b/internal/proxy/testoutput/oapi-proxy_test.go index a7d61d2..7971cca 100644 --- a/internal/proxy/testoutput/oapi-proxy_test.go +++ b/internal/proxy/testoutput/oapi-proxy_test.go @@ -31,7 +31,7 @@ func TestProxy(t *testing.T) { t.Run("Standard", func(t *testing.T) { e := echo.New() - sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{insertTenantIDMiddleware(tenantID)}) + sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{injectTenantIDMiddleware(tenantID)}) RegisterHandlers(e, sh) id := uuid.NewString() @@ -69,7 +69,7 @@ func TestProxy(t *testing.T) { e := echo.New() sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{ selectivePasstroughMiddleware(), - insertTenantIDMiddleware(tenantID), + injectTenantIDMiddleware(tenantID), }) RegisterHandlers(e, sh) @@ -109,34 +109,40 @@ func TestProxy(t *testing.T) { e := echo.New() sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{ authz(), - insertTenantIDMiddleware(tenantID), + injectTenantIDMiddleware(tenantID), }) RegisterHandlers(e, sh) id := uuid.NewString() testtable := []struct { - name string - i string - o string - code int + name string + path string + method string + code int }{ { - name: "Authorized", - i: "/profiles/" + id, - o: "/tenants/" + tenantID.String() + "/profiles/" + id, - code: http.StatusAccepted, + name: "Authorized", + path: "/profiles/" + id, + method: http.MethodGet, + code: http.StatusAccepted, }, { - name: "NotAuthorized", - i: "/validated-profiles/" + id, - o: "", - code: http.StatusForbidden, + name: "AuthorizedWithoutAdditionalExpectation", + path: "/profiles/" + id, + method: http.MethodPut, + code: http.StatusAccepted, + }, + { + name: "Unauthorized", + path: "/validated-profiles/" + id, + method: http.MethodGet, + code: http.StatusForbidden, }, } for _, d := range testtable { t.Run(d.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, d.i, nil) + req := httptest.NewRequest(d.method, d.path, nil) res := httptest.NewRecorder() e.ServeHTTP(res, req)