diff --git a/cmd/api/src/api/v2/auth/oidc.go b/cmd/api/src/api/v2/auth/oidc.go index 4132ed456..e1e0d0a04 100644 --- a/cmd/api/src/api/v2/auth/oidc.go +++ b/cmd/api/src/api/v2/auth/oidc.go @@ -46,11 +46,12 @@ var ( ) type oidcClaims struct { - Name string `json:"name"` - FamilyName string `json:"family_name"` - DisplayName string `json:"given_name"` - Email string `json:"email"` - Verified bool `json:"email_verified"` + Name string `json:"name"` + FamilyName string `json:"family_name"` + DisplayName string `json:"given_name"` + Email string `json:"email"` + Verified bool `json:"email_verified"` + Roles []string `json:"roles"` } // UpsertOIDCProviderRequest represents the body of create & update provider endpoints @@ -255,15 +256,17 @@ func getOIDCClaims(reqCtx context.Context, provider *oidc.Provider, ssoProvider } func jitOIDCUserCreation(ctx context.Context, ssoProvider model.SSOProvider, claims oidcClaims, u jitUserCreator) error { - if role, err := u.GetRole(ctx, ssoProvider.Config.AutoProvision.DefaultRoleId); err != nil { - return fmt.Errorf("get role: %v", err) + if roles, err := sanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, claims.Roles, u); err != nil { + return fmt.Errorf("sanitizeAndGetRoles: %v", err) + } else if len(roles) != 1 { + return fmt.Errorf("invalid roles %v", roles.Names()) } else if _, err := u.LookupUser(ctx, claims.Email); err != nil && !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("lookup user: %v", err) } else if errors.Is(err, database.ErrNotFound) { var user = model.User{ EmailAddress: null.StringFrom(claims.Email), PrincipalName: claims.Email, - Roles: model.Roles{role}, + Roles: roles, SSOProviderID: null.Int32From(ssoProvider.ID), EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users FirstName: null.StringFrom(claims.Email), diff --git a/cmd/api/src/api/v2/auth/saml.go b/cmd/api/src/api/v2/auth/saml.go index bc14b1256..2e37702b1 100644 --- a/cmd/api/src/api/v2/auth/saml.go +++ b/cmd/api/src/api/v2/auth/saml.go @@ -473,15 +473,17 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re } func jitSAMLUserCreation(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserCreator) error { - if role, err := u.GetRole(ctx, ssoProvider.Config.AutoProvision.DefaultRoleId); err != nil { - return fmt.Errorf("get role: %v", err) + if roles, err := sanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, ssoProvider.SAMLProvider.GetSAMLUserRolesFromAssertion(assertion), u); err != nil { + return fmt.Errorf("sanitizeAndGetRoles: %v", err) + } else if len(roles) != 1 { + return fmt.Errorf("invalid roles %v", roles.Names()) } else if _, err := u.LookupUser(ctx, principalName); err != nil && !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("lookup user: %v", err) } else if errors.Is(err, database.ErrNotFound) { user := model.User{ EmailAddress: null.StringFrom(principalName), PrincipalName: principalName, - Roles: model.Roles{role}, + Roles: roles, SSOProviderID: null.Int32From(ssoProvider.ID), EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users FirstName: null.StringFrom(principalName), diff --git a/cmd/api/src/api/v2/auth/sso.go b/cmd/api/src/api/v2/auth/sso.go index 7f5032382..a41a2b937 100644 --- a/cmd/api/src/api/v2/auth/sso.go +++ b/cmd/api/src/api/v2/auth/sso.go @@ -18,6 +18,7 @@ package auth import ( "context" + "fmt" "net/http" "net/url" "path" @@ -25,6 +26,8 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/specterops/bloodhound/dawgs/cardinality" + "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/api" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/ctx" @@ -59,8 +62,12 @@ type getRoler interface { GetRole(ctx context.Context, roleID int32) (model.Role, error) } +type getAllRoler interface { + GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error) +} + type jitUserCreator interface { - getRoler + getAllRoler LookupUser(ctx context.Context, principalNameOrEmail string) (model.User, error) CreateUser(ctx context.Context, user model.User) (model.User, error) @@ -242,3 +249,46 @@ func (s ManagementResource) SSOCallbackHandler(response http.ResponseWriter, req } } } + +func sanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.SSOProviderAutoProvisionConfig, maybeBHRoles []string, r getAllRoler) (model.Roles, error) { + if dbRoles, err := r.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil { + return nil, err + } else { + var defaultRole model.Role + dbRolesBySlug := make(map[string]*model.Role) + // Make quick lookup by role slug -> lower cased, dashes for spaces, and prefixed by `bh` e.g. bh-power-user + for _, r := range dbRoles { + dbRolesBySlug[fmt.Sprintf("bh-%s", strings.ReplaceAll(strings.ToLower(r.Name), " ", "-"))] = &r + if r.ID == autoProvisionConfig.DefaultRoleId { + defaultRole = r + } + } + + if autoProvisionConfig.RoleProvision { + var validRoles model.Roles + validRolesSeen := cardinality.NewBitmap32() // Ensure no dupes + // Only add valid roles + for _, r := range maybeBHRoles { + if dbRole := dbRolesBySlug[strings.ReplaceAll(strings.ToLower(r), " ", "-")]; dbRole != nil && !validRolesSeen.Contains(uint32(dbRole.ID)) { + validRoles = append(validRoles, *dbRole) + validRolesSeen.Add(uint32(dbRole.ID)) + } + } + switch { + case len(validRoles) == 1: + return validRoles, nil + case len(validRoles) > 1: + log.Warnf("[SSO] JIT Role Provision detected multiple valid roles - %s , falling back to default role %s", validRoles.Names(), defaultRole.Name) + default: + log.Warnf("[SSO] JIT Role Provision detected no valid roles from %s , falling back to default role %s", maybeBHRoles, defaultRole.Name) + } + } + + /* Fallback to default role: + - Role provision is disabled + - Role provision is enabled but no valid roles are found + - Role provision is enabled but multiple valid roles are found + */ + return model.Roles{defaultRole}, nil + } +} diff --git a/cmd/api/src/model/auth.go b/cmd/api/src/model/auth.go index b86b49328..717005003 100644 --- a/cmd/api/src/model/auth.go +++ b/cmd/api/src/model/auth.go @@ -345,6 +345,16 @@ func (s Roles) GetValidFilterPredicatesAsStrings(column string) ([]string, error } } +func (s Roles) Names() []string { + names := make([]string, len(s)) + + for idx, role := range s { + names[idx] = role.Name + } + + return names +} + func (s Roles) IDs() []int32 { ids := make([]int32, len(s)) diff --git a/cmd/api/src/model/samlprovider.go b/cmd/api/src/model/samlprovider.go index 7d9742713..a851d140a 100644 --- a/cmd/api/src/model/samlprovider.go +++ b/cmd/api/src/model/samlprovider.go @@ -39,6 +39,7 @@ const ( XMLSOAPClaimsGivenName = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname" XMLSOAPClaimsName = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name" XMLSOAPClaimsSurname = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname" + MicrosoftClaimsRole = "http://schemas.microsoft.com/ws/2008/06/identity/claims/role" ) var ( @@ -118,6 +119,11 @@ func (s SAMLProvider) givenNameAttributeNames() []string { return []string{ObjectIDGivenName, XMLSOAPClaimsGivenName, ObjectIDName, XMLSOAPClaimsName} } +func (s SAMLProvider) roleAttributeNames() []string { + // Added the MicrosoftClaimsRole as a fallback + return []string{MicrosoftClaimsRole} +} + func (s SAMLProvider) surnameAttributeNames() []string { return []string{ObjectIDSurname, XMLSOAPClaimsSurname} } @@ -165,6 +171,23 @@ func (s SAMLProvider) GetSAMLUserGivenNameFromAssertion(assertion *saml.Assertio return assertionFindString(assertion, s.givenNameAttributeNames()...) } +// GetSAMLUserRolesFromAssertion May be empty if not present +func (s SAMLProvider) GetSAMLUserRolesFromAssertion(assertion *saml.Assertion) (roles []string) { + for _, attributeStatement := range assertion.AttributeStatements { + for _, attribute := range attributeStatement.Attributes { + for _, validName := range s.roleAttributeNames() { + if attribute.Name == validName && len(attribute.Values) > 0 { + for _, value := range attribute.Values { + roles = append(roles, value.Value) + } + } + } + } + } + + return roles +} + func (s SAMLProvider) GetSAMLUserSurnameFromAssertion(assertion *saml.Assertion) (string, error) { return assertionFindString(assertion, s.surnameAttributeNames()...) }