From 4c4b4464ca079051148ca9c3db995c9979ceb54c Mon Sep 17 00:00:00 2001 From: James Date: Tue, 29 Mar 2022 11:02:02 +0700 Subject: [PATCH] feat: pass input to validation function to allow checks on challenge inputs (#17) --- examples/graphql_flow/flows/single_flow.go | 2 +- .../flows/single_flow.go | 2 +- .../single_flow_single_challenge/flows/single_flow.go | 2 +- flow/interface.go | 2 +- mfa/mfa.go | 8 ++++---- mfa/mfa_test.go | 10 +++++----- mocks/mock_flow.go | 8 ++++---- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/graphql_flow/flows/single_flow.go b/examples/graphql_flow/flows/single_flow.go index 5d04697..9968821 100644 --- a/examples/graphql_flow/flows/single_flow.go +++ b/examples/graphql_flow/flows/single_flow.go @@ -26,7 +26,7 @@ func (f SingleFlow) Initialize(ctx context.Context) (*JWTEntities.JWTAdditions, }, nil } -func (f SingleFlow) Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData) error { +func (f SingleFlow) Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData, challengeInput *string) error { //TODO implement me return nil } diff --git a/examples/single_flow_multiple_challenges/flows/single_flow.go b/examples/single_flow_multiple_challenges/flows/single_flow.go index f4c0ede..c89ad63 100644 --- a/examples/single_flow_multiple_challenges/flows/single_flow.go +++ b/examples/single_flow_multiple_challenges/flows/single_flow.go @@ -22,7 +22,7 @@ func (f SingleFlow) Resolve(jwtData mfaEntities.JWTData) (*map[string]interface{ }, nil } -func (f SingleFlow) Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData) (context.Context, error) { +func (f SingleFlow) Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData, challengeInput *string) (context.Context, error) { //TODO implement me return ctx, nil } diff --git a/examples/single_flow_single_challenge/flows/single_flow.go b/examples/single_flow_single_challenge/flows/single_flow.go index 5acedc6..0120f46 100644 --- a/examples/single_flow_single_challenge/flows/single_flow.go +++ b/examples/single_flow_single_challenge/flows/single_flow.go @@ -26,7 +26,7 @@ func (f SingleFlow) Initialize(ctx context.Context) (*JWTEntities.JWTAdditions, }, nil } -func (f SingleFlow) Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData) (context.Context, error) { +func (f SingleFlow) Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData, challengeInput *string) (context.Context, error) { //TODO implement me return ctx, nil } diff --git a/flow/interface.go b/flow/interface.go index f5fc7f3..5d48668 100644 --- a/flow/interface.go +++ b/flow/interface.go @@ -11,7 +11,7 @@ type IFlow interface { Solve(ctx context.Context, challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error) Request(ctx context.Context, challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error) Resolve(JWTData mfaEntities.JWTData) (*map[string]interface{}, error) - Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData) (context.Context, error) + Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData, challengeInput *string) (context.Context, error) GetChallenges() []string GetName() string Initialize(ctx context.Context) (*entities.JWTAdditions, error) diff --git a/mfa/mfa.go b/mfa/mfa.go index 9e61d5f..85b9d09 100644 --- a/mfa/mfa.go +++ b/mfa/mfa.go @@ -47,7 +47,7 @@ func (m *Service) decodeJWT(jwt string) (*entities.JWTData, error) { return &decodedJWT, nil } -func (m *Service) getFlow(ctx context.Context, flow string, decodedJWT *entities.JWTData, challenge *string) (context.Context, flow.IFlow, error) { +func (m *Service) getFlow(ctx context.Context, flow string, decodedJWT *entities.JWTData, challenge *string, input *string) (context.Context, flow.IFlow, error) { requestedFlow := m.Flows[flow] if requestedFlow == nil { return ctx, nil, errors.New("Flow not found") @@ -56,7 +56,7 @@ func (m *Service) getFlow(ctx context.Context, flow string, decodedJWT *entities if challenge == nil { return ctx, requestedFlow, nil } - newCtx, err := requestedFlow.Validate(ctx, *challenge, *decodedJWT) + newCtx, err := requestedFlow.Validate(ctx, *challenge, *decodedJWT, input) if err != nil { return ctx, nil, err } @@ -69,7 +69,7 @@ func (m *Service) Process(ctx context.Context, jwt string, challenge string, inp if err != nil { return nil, err } - newCtx, requestFlow, err := m.getFlow(ctx, decodedJWT.Flow, decodedJWT, &challenge) + newCtx, requestFlow, err := m.getFlow(ctx, decodedJWT.Flow, decodedJWT, &challenge, &input) if err != nil { return nil, err } @@ -101,7 +101,7 @@ func (m *Service) setContext(ctx context.Context, requestFlow flow.IFlow, input } func (m *Service) Request(ctx context.Context, flow string, input *FlowInput) (*entities.MFAResult, error) { - newCtx, requestFlow, err := m.getFlow(ctx, flow, &entities.JWTData{}, nil) + newCtx, requestFlow, err := m.getFlow(ctx, flow, &entities.JWTData{}, nil, nil) if err != nil { return nil, err } diff --git a/mfa/mfa_test.go b/mfa/mfa_test.go index c9f5f11..407df75 100644 --- a/mfa/mfa_test.go +++ b/mfa/mfa_test.go @@ -57,7 +57,7 @@ func TestNewMFAService(t *testing.T) { jwtService.EXPECT().GenerateToken(gomock.Any(), gomock.Any()).Return(validJWT, nil) mockflow.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) mockflow.EXPECT().GetName().Return("test") mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) @@ -138,7 +138,7 @@ func TestNewMFAService(t *testing.T) { jwtService.EXPECT().GenerateToken(gomock.Any(), gomock.Any()).Return(validJWT, nil) mockflow.EXPECT().Solve(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) mockflow.EXPECT().GetName().Return("test") mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) @@ -175,7 +175,7 @@ func TestNewMFAService(t *testing.T) { jwtService.EXPECT().GenerateToken(gomock.Any(), gomock.Any()).Return(validJWT, nil) mockflow.EXPECT().Solve(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("Failed to solve")) - mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) mockflow.EXPECT().GetName().Return("test") mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) @@ -365,7 +365,7 @@ func TestNewMFAService(t *testing.T) { jwtService.EXPECT().GenerateToken(gomock.Any(), gomock.Any()).Return(validJWT, nil) mockflow.EXPECT().Solve(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) mockflow.EXPECT().GetName().Return("test") mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) @@ -401,7 +401,7 @@ func TestNewMFAService(t *testing.T) { jwtService.EXPECT().GenerateToken(gomock.Any(), gomock.Any()).Return(validJWT, nil) mockflow.EXPECT().Solve(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockflow.EXPECT().Validate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) mockflow.EXPECT().GetName().Return("test") mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) mockflow.EXPECT().GetChallenges().Return([]string{"dummy"}) diff --git a/mocks/mock_flow.go b/mocks/mock_flow.go index 119f0d5..9f0b8c3 100644 --- a/mocks/mock_flow.go +++ b/mocks/mock_flow.go @@ -181,16 +181,16 @@ func (mr *MockIFlowMockRecorder) Solve(arg0, arg1, arg2, arg3 interface{}) *gomo } // Validate mocks base method. -func (m *MockIFlow) Validate(arg0 context.Context, arg1 string, arg2 entities0.JWTData) (context.Context, error) { +func (m *MockIFlow) Validate(arg0 context.Context, arg1 string, arg2 entities0.JWTData, arg3 *string) (context.Context, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Validate", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Validate", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(context.Context) ret1, _ := ret[1].(error) return ret0, ret1 } // Validate indicates an expected call of Validate. -func (mr *MockIFlowMockRecorder) Validate(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockIFlowMockRecorder) Validate(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockIFlow)(nil).Validate), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockIFlow)(nil).Validate), arg0, arg1, arg2, arg3) }