Skip to content

Commit

Permalink
add context
Browse files Browse the repository at this point in the history
  • Loading branch information
rucciva committed Aug 24, 2024
1 parent 8e5569b commit 92716c6
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions internal/proxy/testoutput/oapi-proxy-impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func (s serverImpl) GetProfile(ctx context.Context, request GetProfileRequestObj
return []authzExpectation{
e.ProfileIDNotZero(request.ProfileId),
e.OR(
func() (bool, error) { return true, nil },
func(context.Context) (bool, error) { return true, nil },
e.AND(
func() (bool, error) { return true, nil },
func(context.Context) (bool, error) { return true, nil },
e.False(),
),
),
Expand Down Expand Up @@ -113,9 +113,11 @@ func authz() strictecho.StrictEchoMiddlewareFunc {
return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc {
return func(ctx echo.Context, request interface{}) (response interface{}, err error) {
// init expectations
exp := authzExpectations{}
exp.tenantID, _ = ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID)
ctx.SetRequest(ctx.Request().WithContext(exp.Attach(ctx.Request().Context())))
ctx.SetRequest(
ctx.Request().WithContext(
ctxWithAuthzExpectation(ctx.Request().Context(), func(ae *authzExpectations) {
ae.tenantID, _ = ctx.Request().Context().Value(ctxTenantID{}).(uuid.UUID)
})))

// exec handler
res, err := f(ctx, request)
Expand All @@ -136,7 +138,7 @@ func authz() strictecho.StrictEchoMiddlewareFunc {
}
}

type authzExpectation func() (bool, error)
type authzExpectation func(ctx context.Context) (bool, error)

type authzExpectations struct {
tenantID uuid.UUID
Expand All @@ -146,22 +148,14 @@ type authzExpectations struct {
err error
}

func (a *authzExpectations) Attach(ctx context.Context) context.Context {
if a == nil {
return ctx
}

return context.WithValue(ctx, authzExpectations{}, a)
}

func (a *authzExpectations) OR(reqs ...func() (bool, error)) authzExpectation {
func (a *authzExpectations) OR(reqs ...authzExpectation) authzExpectation {
if a == nil {
a.False()
}

return func() (oks bool, errs error) {
return func(ctx context.Context) (oks bool, errs error) {
for _, step := range reqs {
ok, err := step()
ok, err := step(ctx)
if !ok || err != nil {
errs = errors.Join(errs, err)
continue
Expand All @@ -173,14 +167,14 @@ func (a *authzExpectations) OR(reqs ...func() (bool, error)) authzExpectation {
}
}

func (a *authzExpectations) AND(reqs ...func() (bool, error)) authzExpectation {
func (a *authzExpectations) AND(reqs ...authzExpectation) authzExpectation {
if a == nil {
a.False()
}

return func() (bool, error) {
return func(ctx context.Context) (bool, error) {
for _, step := range reqs {
ok, err := step()
ok, err := step(ctx)
if !ok || err != nil {
return ok, err
}
Expand All @@ -190,15 +184,15 @@ func (a *authzExpectations) AND(reqs ...func() (bool, error)) authzExpectation {
}

func (a *authzExpectations) False() authzExpectation {
return func() (bool, error) { return false, nil }
return func(context.Context) (bool, error) { return false, nil }
}

func (a *authzExpectations) tenantIDNotZero() authzExpectation {
if a == nil {
a.False()
}

return func() (bool, error) {
return func(context.Context) (bool, error) {
return a.tenantID != uuid.UUID{}, nil
}
}
Expand All @@ -208,11 +202,19 @@ func (a *authzExpectations) ProfileIDNotZero(profileID uuid.UUID) authzExpectati
a.False()
}

return func() (bool, error) {
return func(context.Context) (bool, error) {
return profileID != uuid.UUID{}, nil
}
}

func ctxWithAuthzExpectation(ctx context.Context, init func(*authzExpectations)) context.Context {
a := &authzExpectations{}
init(a)

return context.WithValue(ctx, authzExpectations{}, a)

}

func authzExpect(ctx context.Context, addExp func(*authzExpectations) []authzExpectation) (bool, error) {
v, _ := (ctx.Value(authzExpectations{})).(*authzExpectations)
if v == nil {
Expand All @@ -232,7 +234,7 @@ func authzExpect(ctx context.Context, addExp func(*authzExpectations) []authzExp

v.invoked = true
for _, exp := range exps {
v.result, v.err = exp()
v.result, v.err = exp(ctx)
if !v.result || v.err != nil {
return v.result, v.err
}
Expand Down

0 comments on commit 92716c6

Please sign in to comment.