Skip to content

Commit

Permalink
feat: pass ctx to challange and add before hook (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
bludot authored Mar 25, 2022
1 parent 3f547c8 commit 5468a7e
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 70 deletions.
6 changes: 4 additions & 2 deletions challenge/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package challenge

import "context"

type IChallenge interface {
Solve(body map[string]interface{}) (*map[string]interface{}, error)
Request(body map[string]interface{}) (*map[string]interface{}, error) // Request a challenge, ex: for OTP you have to request an OTP before you can solve it
Solve(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error)
Request(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) // Request a challenge, ex: for OTP you have to request an OTP before you can solve it
}
5 changes: 3 additions & 2 deletions examples/single_flow_multiple_challenges/challenges/dummy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package challenges

import (
"context"
"errors"
"log"
"math/rand"
Expand All @@ -25,7 +26,7 @@ type DummyChallenge struct {
Seed string `json:"seed"`
}

func (c *DummyChallenge) Solve(body map[string]interface{}) (*map[string]interface{}, error) {
func (c *DummyChallenge) Solve(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) {
log.Println("seed:", c.Seed)
log.Println("password:", body["password"])
if body["username"] == "admin" && body["password"].(string) == c.Seed {
Expand All @@ -34,7 +35,7 @@ func (c *DummyChallenge) Solve(body map[string]interface{}) (*map[string]interfa
return nil, errors.New("failed!")
}

func (c *DummyChallenge) Request(body map[string]interface{}) (*map[string]interface{}, error) {
func (c *DummyChallenge) Request(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) {
rand.Seed(time.Now().UnixNano())
c.Seed = randSeq(10)
log.Println("Seed:", c.Seed)
Expand Down
5 changes: 3 additions & 2 deletions examples/single_flow_multiple_challenges/challenges/dummy2.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package challenges

import (
"context"
"errors"
"log"
"math/rand"
Expand All @@ -15,7 +16,7 @@ type DummyTwoChallenge struct {
Seed string `json:"seed"`
}

func (c *DummyTwoChallenge) Solve(body map[string]interface{}) (*map[string]interface{}, error) {
func (c *DummyTwoChallenge) Solve(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) {
log.Println("seed:", c.Seed)
log.Println("password:", body["password"])
if body["username"] == "admin" && body["password"].(string) == c.Seed {
Expand All @@ -24,7 +25,7 @@ func (c *DummyTwoChallenge) Solve(body map[string]interface{}) (*map[string]inte
return nil, errors.New("failed!")
}

func (c *DummyTwoChallenge) Request(body map[string]interface{}) (*map[string]interface{}, error) {
func (c *DummyTwoChallenge) Request(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) {
rand.Seed(time.Now().UnixNano())
c.Seed = randSeq(10)
log.Println("Seed:", c.Seed)
Expand Down
8 changes: 4 additions & 4 deletions examples/single_flow_multiple_challenges/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ func main() {
log.Println(fail)

jwt = res.Token
res, err = mfaService.Process(context.TODO(), jwt, "dummy", fail, false)
res, err = mfaService.Process(context.TODO(), jwt, "dummy", fail, false, nil)
if err != nil {
log.Println("Failed")
}
resJSON, _ = json.Marshal(*res)
log.Println(string(resJSON))

jwt = res.Token
res, err = mfaService.Process(context.TODO(), jwt, "dummy", pass, false)
res, err = mfaService.Process(context.TODO(), jwt, "dummy", pass, false, nil)
if err != nil {
log.Println("Failed")
}
resJSON, _ = json.Marshal(*res)
log.Println(string(resJSON))

jwt = res.Token
res, err = mfaService.Process(context.TODO(), jwt, "dummy2", pass, true)
res, err = mfaService.Process(context.TODO(), jwt, "dummy2", pass, true, nil)
if err != nil {
log.Println("Failed")
}
Expand All @@ -72,7 +72,7 @@ func main() {
log.Println(pass)
log.Println(fail)

res, err = mfaService.Process(context.TODO(), jwt, "dummy2", pass, false)
res, err = mfaService.Process(context.TODO(), jwt, "dummy2", pass, false, nil)
if err != nil {
log.Println("Failed")
}
Expand Down
5 changes: 3 additions & 2 deletions examples/single_flow_single_challenge/challenges/dummy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package challenges

import (
"context"
"errors"
"log"
"math/rand"
Expand All @@ -25,7 +26,7 @@ type DummyChallenge struct {
Seed string `json:"seed"`
}

func (c *DummyChallenge) Solve(body map[string]interface{}) (*map[string]interface{}, error) {
func (c *DummyChallenge) Solve(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) {
log.Println("seed:", c.Seed)
log.Println("password:", body["password"])
if body["username"] == "admin" && body["password"].(string) == c.Seed {
Expand All @@ -34,7 +35,7 @@ func (c *DummyChallenge) Solve(body map[string]interface{}) (*map[string]interfa
return nil, errors.New("failed!")
}

func (c *DummyChallenge) Request(body map[string]interface{}) (*map[string]interface{}, error) {
func (c *DummyChallenge) Request(ctx context.Context, body map[string]interface{}) (*map[string]interface{}, error) {
rand.Seed(time.Now().UnixNano())
c.Seed = randSeq(10)
log.Println("Seed:", c.Seed)
Expand Down
4 changes: 2 additions & 2 deletions examples/single_flow_single_challenge/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ func main() {
log.Println(fail)

// Attempt to solve
res, err = mfaService.Process(context.TODO(), jwt, "dummy", fail, false)
res, err = mfaService.Process(context.TODO(), jwt, "dummy", fail, false, nil)
if err != nil {
log.Println("Failed")
}
resJSON, _ = json.Marshal(*res)
log.Println(string(resJSON))
jwt = res.Token

res, err = mfaService.Process(context.TODO(), jwt, "dummy", pass, false)
res, err = mfaService.Process(context.TODO(), jwt, "dummy", pass, false, nil)
if err != nil {
log.Println("Failed")
}
Expand Down
9 changes: 5 additions & 4 deletions flow/entities/flow.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package entities

import (
"context"
"encoding/json"
"errors"

Expand All @@ -17,27 +18,27 @@ func (f Flow) GetName() string {
return f.Name
}

func (f Flow) Solve(challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error) {
func (f Flow) Solve(ctx context.Context, challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error) {
var marshaledInput map[string]interface{}
err := json.Unmarshal([]byte(input), &marshaledInput)
if err != nil {
return nil, err
}
if challenge, ok := f.Challenges[challenge]; ok {
return challenge.Solve(marshaledInput)
return challenge.Solve(ctx, marshaledInput)
}

return nil, errors.New("Challenge not found")
}

func (f Flow) Request(challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error) {
func (f Flow) Request(ctx context.Context, challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error) {
var marshaledInput map[string]interface{}
err := json.Unmarshal([]byte(input), &marshaledInput)
if err != nil {
return nil, err
}
if challenge, ok := f.Challenges[challenge]; ok {
return challenge.Request(marshaledInput)
return challenge.Request(ctx, marshaledInput)
}

return nil, errors.New("Challenge not found")
Expand Down
21 changes: 11 additions & 10 deletions flow/entities/flow_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package entities_test

import (
"context"
"errors"
"testing"

Expand Down Expand Up @@ -44,9 +45,9 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

dummyChallenge.EXPECT().Solve(gomock.Any()).Return(nil, nil)
dummyChallenge.EXPECT().Solve(gomock.Any(), gomock.Any()).Return(nil, nil)

solved, err := flow.Solve("dummy", "{}", mfaEntities.JWTData{
solved, err := flow.Solve(context.TODO(), "dummy", "{}", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand All @@ -71,9 +72,9 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

dummyChallenge.EXPECT().Solve(gomock.Any()).Return(nil, errors.New("error"))
dummyChallenge.EXPECT().Solve(gomock.Any(), gomock.Any()).Return(nil, errors.New("error"))

solved, err := flow.Solve("dummy", "{}", mfaEntities.JWTData{
solved, err := flow.Solve(context.TODO(), "dummy", "{}", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand All @@ -98,7 +99,7 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

solved, err := flow.Solve("dummy", "", mfaEntities.JWTData{
solved, err := flow.Solve(context.TODO(), "dummy", "", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand All @@ -123,7 +124,7 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

solved, err := flow.Solve("dummy2", "{}", mfaEntities.JWTData{
solved, err := flow.Solve(context.TODO(), "dummy2", "{}", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand All @@ -148,9 +149,9 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

dummyChallenge.EXPECT().Request(gomock.Any()).Return(nil, nil)
dummyChallenge.EXPECT().Request(gomock.Any(), gomock.Any()).Return(nil, nil)

reqested, err := flow.Request("dummy", "{}", mfaEntities.JWTData{
reqested, err := flow.Request(context.TODO(), "dummy", "{}", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand All @@ -176,7 +177,7 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

reqested, err := flow.Request("dummy2", "{}", mfaEntities.JWTData{
reqested, err := flow.Request(context.TODO(), "dummy2", "{}", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand All @@ -203,7 +204,7 @@ func TestFlow_GetChallenges(t *testing.T) {
},
}

reqested, err := flow.Request("dummy", "", mfaEntities.JWTData{
reqested, err := flow.Request(context.TODO(), "dummy", "", mfaEntities.JWTData{
Flow: "test",
Challenges: map[string]mfaEntities.Challenge{
"dummy": {
Expand Down
4 changes: 2 additions & 2 deletions flow/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
)

type IFlow interface {
Solve(challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error)
Request(challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error)
Solve(ctx context.Context, challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error)
Request(ctx context.Context, challenge string, input string, JWTData mfaEntities.JWTData) (*map[string]interface{}, error)
Resolve(JWTData mfaEntities.JWTData) (*map[string]interface{}, error)
Validate(ctx context.Context, challenge string, JWTData mfaEntities.JWTData) error
GetChallenges() []string
Expand Down
2 changes: 1 addition & 1 deletion mfa/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (

type IMFAService interface {
Request(ctx context.Context, flow string) (*entities.MFAResult, error)
Process(ctx context.Context, jwt string, challenge string, input string, request bool) (*entities.MFAResult, error)
Process(ctx context.Context, jwt string, challenge string, input string, request bool, beforeHook *func(ctx context.Context, challenge string, input string) error) (*entities.MFAResult, error)
}
22 changes: 14 additions & 8 deletions mfa/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (m *Service) getFlow(ctx context.Context, flow string, decodedJWT *entities
return requestedFlow, nil
}

func (m *Service) Process(ctx context.Context, jwt string, challenge string, input string, request bool) (*entities.MFAResult, error) {
func (m *Service) Process(ctx context.Context, jwt string, challenge string, input string, request bool, beforeHook *func(ctx context.Context, challenge string, input string) error) (*entities.MFAResult, error) {
decodedJWT, err := m.decodeJWT(jwt)
if err != nil {
return nil, err
Expand All @@ -69,11 +69,17 @@ func (m *Service) Process(ctx context.Context, jwt string, challenge string, inp
return nil, err
}

if beforeHook != nil {
err = (*beforeHook)(ctx, challenge, input)
if err != nil {
return nil, err
}
}
if request {
return m.handleRequest(*decodedJWT, challenge, input, requestFlow)
return m.handleRequest(ctx, *decodedJWT, challenge, input, requestFlow)
}

return m.handleSolve(*decodedJWT, challenge, input, requestFlow)
return m.handleSolve(ctx, *decodedJWT, challenge, input, requestFlow)
}

func (m *Service) Request(ctx context.Context, flow string) (*entities.MFAResult, error) {
Expand All @@ -88,7 +94,7 @@ func (m *Service) Request(ctx context.Context, flow string) (*entities.MFAResult
return nil, err
}

return m.handleRequest(entities.JWTData{
return m.handleRequest(ctx, entities.JWTData{
Flow: flow,
Identifier: additionalJWTData.Identifier,
Type: additionalJWTData.Type,
Expand Down Expand Up @@ -125,8 +131,8 @@ func (m *Service) getChallengeStatus(claimStatus string) entities.Challenge {
}
}

func (m *Service) handleRequest(decodedJWT entities.JWTData, challenge string, input string, requestFlow flow.IFlow) (*entities.MFAResult, error) {
result, err := requestFlow.Request(challenge, input, decodedJWT)
func (m *Service) handleRequest(ctx context.Context, decodedJWT entities.JWTData, challenge string, input string, requestFlow flow.IFlow) (*entities.MFAResult, error) {
result, err := requestFlow.Request(ctx, challenge, input, decodedJWT)

resultJson, _ := json.Marshal(result)
resultJsonString := string(resultJson)
Expand Down Expand Up @@ -165,8 +171,8 @@ func (m *Service) handleRequest(decodedJWT entities.JWTData, challenge string, i
}, nil
}

func (m *Service) handleSolve(decodedJWT entities.JWTData, challenge string, input string, requestFlow flow.IFlow) (*entities.MFAResult, error) {
result, err := requestFlow.Solve(challenge, input, decodedJWT)
func (m *Service) handleSolve(ctx context.Context, decodedJWT entities.JWTData, challenge string, input string, requestFlow flow.IFlow) (*entities.MFAResult, error) {
result, err := requestFlow.Solve(ctx, challenge, input, decodedJWT)

resultJson, _ := json.Marshal(result)
resultJsonString := string(resultJson)
Expand Down
Loading

0 comments on commit 5468a7e

Please sign in to comment.