Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rucciva committed Aug 24, 2024
1 parent 77e79ac commit 8e5569b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
53 changes: 36 additions & 17 deletions internal/proxy/testoutput/oapi-proxy-impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (s serverImpl) GetProfile(ctx context.Context, request GetProfileRequestObj

// GetValidatedProfile implements StrictServerInterface.
func (s serverImpl) GetValidatedProfile(ctx context.Context, request GetValidatedProfileRequestObject) (UpstreamProfileGetProfileRequestObject, error) {
authzExpect(ctx, func(ae *authzExpectations) []authzExpectation { return []authzExpectation{ae.False()} })
return UpstreamProfileGetProfileRequestObject{
TenantId: ctx.Value(ctxTenantID{}).(uuid.UUID),
ProfileId: request.ProfileId,
Expand All @@ -77,7 +78,7 @@ func (s serverImpl) PutProfile(ctx context.Context, request PutProfileRequestObj

type ctxTenantID struct{}

func insertTenantIDMiddleware(tenantID uuid.UUID) strictecho.StrictEchoMiddlewareFunc {
func injectTenantIDMiddleware(tenantID uuid.UUID) strictecho.StrictEchoMiddlewareFunc {
return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc {
return func(ctx echo.Context, request interface{}) (response interface{}, err error) {
ctx.SetRequest(
Expand Down Expand Up @@ -111,9 +112,9 @@ func selectivePasstroughMiddleware() strictecho.StrictEchoMiddlewareFunc {
func authz() strictecho.StrictEchoMiddlewareFunc {
return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc {
return func(ctx echo.Context, request interface{}) (response interface{}, err error) {
exp := authzExpectations{
tenantID: ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID),
}
// init expectations
exp := authzExpectations{}
exp.tenantID, _ = ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID)
ctx.SetRequest(ctx.Request().WithContext(exp.Attach(ctx.Request().Context())))

// exec handler
Expand All @@ -122,11 +123,12 @@ func authz() strictecho.StrictEchoMiddlewareFunc {
return nil, err
}

// assert permission result
if exp.err != nil {
// verify expectations result
result, err := authzExpect(ctx.Request().Context(), nil)
if err != nil {
return nil, err
}
if !exp.result {
if !result {
return nil, echo.NewHTTPError(http.StatusForbidden, "forbidden")
}
return res, err
Expand All @@ -139,8 +141,9 @@ type authzExpectation func() (bool, error)
type authzExpectations struct {
tenantID uuid.UUID

result bool
err error
invoked bool
result bool
err error
}

func (a *authzExpectations) Attach(ctx context.Context) context.Context {
Expand Down Expand Up @@ -190,6 +193,16 @@ func (a *authzExpectations) False() authzExpectation {
return func() (bool, error) { return false, nil }
}

func (a *authzExpectations) tenantIDNotZero() authzExpectation {
if a == nil {
a.False()
}

return func() (bool, error) {
return a.tenantID != uuid.UUID{}, nil
}
}

func (a *authzExpectations) ProfileIDNotZero(profileID uuid.UUID) authzExpectation {
if a == nil {
a.False()
Expand All @@ -200,25 +213,31 @@ func (a *authzExpectations) ProfileIDNotZero(profileID uuid.UUID) authzExpectati
}
}

func authzExpect(ctx context.Context, f func(*authzExpectations) []authzExpectation) (bool, error) {
func authzExpect(ctx context.Context, addExp func(*authzExpectations) []authzExpectation) (bool, error) {
v, _ := (ctx.Value(authzExpectations{})).(*authzExpectations)
if v == nil {
return false, nil
}

// required expectation
if v.tenantID == (uuid.UUID{}) {
return false, nil
if v.invoked {
return v.result, v.err
}

// additional expectation
for _, req := range f(v) {
v.result, v.err = req()
// required expectations
exps := []authzExpectation{v.tenantIDNotZero()}
// additional expectations
if addExp != nil {
exps = append(exps, addExp(v)...)
}

v.invoked = true
for _, exp := range exps {
v.result, v.err = exp()
if !v.result || v.err != nil {
return v.result, v.err
}
}

v.result = true

return v.result, nil
}
38 changes: 22 additions & 16 deletions internal/proxy/testoutput/oapi-proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestProxy(t *testing.T) {

t.Run("Standard", func(t *testing.T) {
e := echo.New()
sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{insertTenantIDMiddleware(tenantID)})
sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{injectTenantIDMiddleware(tenantID)})
RegisterHandlers(e, sh)

id := uuid.NewString()
Expand Down Expand Up @@ -69,7 +69,7 @@ func TestProxy(t *testing.T) {
e := echo.New()
sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{
selectivePasstroughMiddleware(),
insertTenantIDMiddleware(tenantID),
injectTenantIDMiddleware(tenantID),
})
RegisterHandlers(e, sh)

Expand Down Expand Up @@ -109,34 +109,40 @@ func TestProxy(t *testing.T) {
e := echo.New()
sh := NewStrictHandler(serverImpl, proxyImpl, []strictecho.StrictEchoMiddlewareFunc{
authz(),
insertTenantIDMiddleware(tenantID),
injectTenantIDMiddleware(tenantID),
})
RegisterHandlers(e, sh)

id := uuid.NewString()
testtable := []struct {
name string
i string
o string
code int
name string
path string
method string
code int
}{
{
name: "Authorized",
i: "/profiles/" + id,
o: "/tenants/" + tenantID.String() + "/profiles/" + id,
code: http.StatusAccepted,
name: "Authorized",
path: "/profiles/" + id,
method: http.MethodGet,
code: http.StatusAccepted,
},
{
name: "NotAuthorized",
i: "/validated-profiles/" + id,
o: "",
code: http.StatusForbidden,
name: "AuthorizedWithoutAdditionalExpectation",
path: "/profiles/" + id,
method: http.MethodPut,
code: http.StatusAccepted,
},
{
name: "Unauthorized",
path: "/validated-profiles/" + id,
method: http.MethodGet,
code: http.StatusForbidden,
},
}

for _, d := range testtable {
t.Run(d.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, d.i, nil)
req := httptest.NewRequest(d.method, d.path, nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

Expand Down

0 comments on commit 8e5569b

Please sign in to comment.