From cb1ac2a6f6a390b2985650f1ae8a9653f553567c Mon Sep 17 00:00:00 2001 From: joel Date: Sat, 28 Sep 2024 11:30:40 +0200 Subject: [PATCH] fix: add option to override on user --- internal/api/mfa.go | 49 ++++++++++++++++++++++++++++++++++++++--- internal/models/user.go | 17 +++++++++----- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index d29abe40d..e295895c3 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -75,12 +75,55 @@ type UnenrollFactorResponse struct { type WebAuthnParams struct { RPID string `json:"rp_id,omitempty"` - // TODO: reconcile this later + // Can encode multiple origins as comma separated values like: "origin1,origin2" RPOrigins string `json:"rp_origins,omitempty"` AssertionResponse *wbnprotocol.CredentialAssertionResponse `json:"assertion_response,omitempty"` CreationResponse *wbnprotocol.CredentialCreationResponse `json:"creation_response,omitempty"` } +func (w *WebAuthnParams) GetRPOrigins() []string { + if w.RPOrigins == "" { + return nil + } + return strings.Split(w.RPOrigins, ",") +} + +func (w *WebAuthnParams) ToConfig() (*webauthn.WebAuthn, error) { + if w.RPID == "" { + return nil, fmt.Errorf("WebAuthn RP ID cannot be empty") + } + + origins := w.GetRPOrigins() + if len(origins) == 0 { + return nil, fmt.Errorf("WebAuthn RP Origins cannot be empty") + } + + var validOrigins []string + var invalidOrigins []string + + for _, origin := range origins { + parsedURL, err := url.Parse(origin) + if err != nil || (parsedURL.Scheme != "https" && parsedURL.Scheme != "http") || parsedURL.Host == "" { + invalidOrigins = append(invalidOrigins, origin) + } else { + validOrigins = append(validOrigins, origin) + } + } + + if len(invalidOrigins) > 0 { + return nil, fmt.Errorf("Invalid RP origins: %s", strings.Join(invalidOrigins, ", ")) + } + + // TODO: Find a more sensible default + wconfig := &webauthn.Config{ + RPDisplayName: w.RPID, + RPID: w.RPID, + RPOrigins: validOrigins, + } + + return webauthn.New(wconfig) +} + const ( QRCodeGenerationErrorMessage = "Error generating QR Code" ) @@ -525,7 +568,7 @@ func (a *API) challengeWebAuthnFactor(w http.ResponseWriter, r *http.Request) er if params.WebAuthn == nil { return badRequestError(ErrorCodeValidationFailed, "WebAuthn config required") } - webAuthn, err := validateWebAuthnConfig(params.WebAuthn) + webAuthn, err := params.WebAuthn.ToConfig() if err != nil { return err } @@ -880,7 +923,7 @@ func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, param case factor.IsUnverified() && params.WebAuthn.CreationResponse == nil: return badRequestError(ErrorCodeValidationFailed, "WebAuthn Creation Response required to login") default: - webAuthn, err = validateWebAuthnConfig(params.WebAuthn) + webAuthn, err = params.WebAuthn.ToConfig() if err != nil { return err } diff --git a/internal/models/user.go b/internal/models/user.go index a289977b0..1abee0c34 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -933,19 +933,26 @@ func (u *User) FindOwnedFactorByID(tx *storage.Connection, factorID uuid.UUID) ( } return &factor, nil } + func (user *User) WebAuthnID() []byte { - // TODO: confirm this user ID + if webAuthnID, ok := user.UserMetaData["web_authn_id"].(string); ok && webAuthnID != "" { + return []byte(webAuthnID) + } return []byte(user.ID.String()) } func (user *User) WebAuthnName() string { - // TODO: Allow for overrides on this - return string(user.Email) + if webAuthnName, ok := user.UserMetaData["web_authn_name"].(string); ok && webAuthnName != "" { + return webAuthnName + } + return user.Email.String() } func (user *User) WebAuthnDisplayName() string { - // TODO: - return string(user.Email) + if webAuthnDisplayName, ok := user.UserMetaData["web_authn_display_name"].(string); ok && webAuthnDisplayName != "" { + return webAuthnDisplayName + } + return user.Email.String() } func (user *User) WebAuthnCredentials() []webauthn.Credential {