diff --git a/internal/proxy/testoutput/oapi-proxy-impl.go b/internal/proxy/testoutput/oapi-proxy-impl.go index 4bcb9e5..4837e7f 100644 --- a/internal/proxy/testoutput/oapi-proxy-impl.go +++ b/internal/proxy/testoutput/oapi-proxy-impl.go @@ -35,9 +35,9 @@ func (s serverImpl) GetProfile(ctx context.Context, request GetProfileRequestObj return []authzExpectation{ e.ProfileIDNotZero(request.ProfileId), e.OR( - func() (bool, error) { return true, nil }, + func(context.Context) (bool, error) { return true, nil }, e.AND( - func() (bool, error) { return true, nil }, + func(context.Context) (bool, error) { return true, nil }, e.False(), ), ), @@ -113,9 +113,11 @@ func authz() strictecho.StrictEchoMiddlewareFunc { return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { return func(ctx echo.Context, request interface{}) (response interface{}, err error) { // init expectations - exp := authzExpectations{} - exp.tenantID, _ = ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID) - ctx.SetRequest(ctx.Request().WithContext(exp.Attach(ctx.Request().Context()))) + ctx.SetRequest( + ctx.Request().WithContext( + ctxWithAuthzExpectation(ctx.Request().Context(), func(ae *authzExpectations) { + ae.tenantID, _ = ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID) + }))) // exec handler res, err := f(ctx, request) @@ -136,7 +138,7 @@ func authz() strictecho.StrictEchoMiddlewareFunc { } } -type authzExpectation func() (bool, error) +type authzExpectation func(ctx context.Context) (bool, error) type authzExpectations struct { tenantID uuid.UUID @@ -146,22 +148,14 @@ type authzExpectations struct { err error } -func (a *authzExpectations) Attach(ctx context.Context) context.Context { - if a == nil { - return ctx - } - - return context.WithValue(ctx, authzExpectations{}, a) -} - -func (a *authzExpectations) OR(reqs ...func() (bool, error)) authzExpectation { +func (a *authzExpectations) OR(reqs ...authzExpectation) authzExpectation { if a == nil { a.False() } - return func() (oks bool, errs error) { + return func(ctx context.Context) (oks bool, errs error) { for _, step := range reqs { - ok, err := step() + ok, err := step(ctx) if !ok || err != nil { errs = errors.Join(errs, err) continue @@ -173,14 +167,14 @@ func (a *authzExpectations) OR(reqs ...func() (bool, error)) authzExpectation { } } -func (a *authzExpectations) AND(reqs ...func() (bool, error)) authzExpectation { +func (a *authzExpectations) AND(reqs ...authzExpectation) authzExpectation { if a == nil { a.False() } - return func() (bool, error) { + return func(ctx context.Context) (bool, error) { for _, step := range reqs { - ok, err := step() + ok, err := step(ctx) if !ok || err != nil { return ok, err } @@ -190,7 +184,7 @@ func (a *authzExpectations) AND(reqs ...func() (bool, error)) authzExpectation { } func (a *authzExpectations) False() authzExpectation { - return func() (bool, error) { return false, nil } + return func(context.Context) (bool, error) { return false, nil } } func (a *authzExpectations) tenantIDNotZero() authzExpectation { @@ -198,7 +192,7 @@ func (a *authzExpectations) tenantIDNotZero() authzExpectation { a.False() } - return func() (bool, error) { + return func(context.Context) (bool, error) { return a.tenantID != uuid.UUID{}, nil } } @@ -208,11 +202,19 @@ func (a *authzExpectations) ProfileIDNotZero(profileID uuid.UUID) authzExpectati a.False() } - return func() (bool, error) { + return func(context.Context) (bool, error) { return profileID != uuid.UUID{}, nil } } +func ctxWithAuthzExpectation(ctx context.Context, init func(*authzExpectations)) context.Context { + a := &authzExpectations{} + init(a) + + return context.WithValue(ctx, authzExpectations{}, a) + +} + func authzExpect(ctx context.Context, addExp func(*authzExpectations) []authzExpectation) (bool, error) { v, _ := (ctx.Value(authzExpectations{})).(*authzExpectations) if v == nil { @@ -232,7 +234,7 @@ func authzExpect(ctx context.Context, addExp func(*authzExpectations) []authzExp v.invoked = true for _, exp := range exps { - v.result, v.err = exp() + v.result, v.err = exp(ctx) if !v.result || v.err != nil { return v.result, v.err }