Skip to content

Commit

Permalink
feat: Add role provision support
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Jan 2, 2025
1 parent ee99753 commit 01f1702
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 12 deletions.
19 changes: 11 additions & 8 deletions cmd/api/src/api/v2/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ var (
)

type oidcClaims struct {
Name string `json:"name"`
FamilyName string `json:"family_name"`
DisplayName string `json:"given_name"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
Name string `json:"name"`
FamilyName string `json:"family_name"`
DisplayName string `json:"given_name"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
Roles []string `json:"roles"`
}

// UpsertOIDCProviderRequest represents the body of create & update provider endpoints
Expand Down Expand Up @@ -255,15 +256,17 @@ func getOIDCClaims(reqCtx context.Context, provider *oidc.Provider, ssoProvider
}

func jitOIDCUserCreation(ctx context.Context, ssoProvider model.SSOProvider, claims oidcClaims, u jitUserCreator) error {
if role, err := u.GetRole(ctx, ssoProvider.Config.AutoProvision.DefaultRoleId); err != nil {
return fmt.Errorf("get role: %v", err)
if roles, err := sanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, claims.Roles, u); err != nil {
return fmt.Errorf("sanitizeAndGetRoles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles %v", roles.Names())
} else if _, err := u.LookupUser(ctx, claims.Email); err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("lookup user: %v", err)
} else if errors.Is(err, database.ErrNotFound) {
var user = model.User{
EmailAddress: null.StringFrom(claims.Email),
PrincipalName: claims.Email,
Roles: model.Roles{role},
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(claims.Email),
Expand Down
8 changes: 5 additions & 3 deletions cmd/api/src/api/v2/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,17 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re
}

func jitSAMLUserCreation(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserCreator) error {
if role, err := u.GetRole(ctx, ssoProvider.Config.AutoProvision.DefaultRoleId); err != nil {
return fmt.Errorf("get role: %v", err)
if roles, err := sanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, ssoProvider.SAMLProvider.GetSAMLUserRolesFromAssertion(assertion), u); err != nil {
return fmt.Errorf("sanitizeAndGetRoles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles %v", roles.Names())
} else if _, err := u.LookupUser(ctx, principalName); err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("lookup user: %v", err)
} else if errors.Is(err, database.ErrNotFound) {
user := model.User{
EmailAddress: null.StringFrom(principalName),
PrincipalName: principalName,
Roles: model.Roles{role},
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(principalName),
Expand Down
52 changes: 51 additions & 1 deletion cmd/api/src/api/v2/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ package auth

import (
"context"
"fmt"
"net/http"
"net/url"
"path"
"strconv"
"strings"

"github.com/gorilla/mux"
"github.com/specterops/bloodhound/dawgs/cardinality"
"github.com/specterops/bloodhound/log"
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/ctx"
Expand Down Expand Up @@ -59,8 +62,12 @@ type getRoler interface {
GetRole(ctx context.Context, roleID int32) (model.Role, error)
}

type getAllRoler interface {
GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error)
}

type jitUserCreator interface {
getRoler
getAllRoler

LookupUser(ctx context.Context, principalNameOrEmail string) (model.User, error)
CreateUser(ctx context.Context, user model.User) (model.User, error)
Expand Down Expand Up @@ -242,3 +249,46 @@ func (s ManagementResource) SSOCallbackHandler(response http.ResponseWriter, req
}
}
}

func sanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.SSOProviderAutoProvisionConfig, maybeBHRoles []string, r getAllRoler) (model.Roles, error) {
if dbRoles, err := r.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil {
return nil, err
} else {
var defaultRole model.Role
dbRolesBySlug := make(map[string]*model.Role)
// Make quick lookup by role slug -> lower cased, dashes for spaces, and prefixed by `bh` e.g. bh-power-user
for _, r := range dbRoles {
dbRolesBySlug[fmt.Sprintf("bh-%s", strings.ReplaceAll(strings.ToLower(r.Name), " ", "-"))] = &r
if r.ID == autoProvisionConfig.DefaultRoleId {
defaultRole = r
}
}

if autoProvisionConfig.RoleProvision {
var validRoles model.Roles
validRolesSeen := cardinality.NewBitmap32() // Ensure no dupes
// Only add valid roles
for _, r := range maybeBHRoles {
if dbRole := dbRolesBySlug[strings.ReplaceAll(strings.ToLower(r), " ", "-")]; dbRole != nil && !validRolesSeen.Contains(uint32(dbRole.ID)) {
validRoles = append(validRoles, *dbRole)
validRolesSeen.Add(uint32(dbRole.ID))
}
}
switch {
case len(validRoles) == 1:
return validRoles, nil
case len(validRoles) > 1:
log.Warnf("[SSO] JIT Role Provision detected multiple valid roles - %s , falling back to default role %s", validRoles.Names(), defaultRole.Name)
default:
log.Warnf("[SSO] JIT Role Provision detected no valid roles from %s , falling back to default role %s", maybeBHRoles, defaultRole.Name)
}
}

/* Fallback to default role:
- Role provision is disabled
- Role provision is enabled but no valid roles are found
- Role provision is enabled but multiple valid roles are found
*/
return model.Roles{defaultRole}, nil
}
}
10 changes: 10 additions & 0 deletions cmd/api/src/model/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,16 @@ func (s Roles) GetValidFilterPredicatesAsStrings(column string) ([]string, error
}
}

func (s Roles) Names() []string {
names := make([]string, len(s))

for idx, role := range s {
names[idx] = role.Name
}

return names
}

func (s Roles) IDs() []int32 {
ids := make([]int32, len(s))

Expand Down
23 changes: 23 additions & 0 deletions cmd/api/src/model/samlprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
XMLSOAPClaimsGivenName = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname"
XMLSOAPClaimsName = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"
XMLSOAPClaimsSurname = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname"
MicrosoftClaimsRole = "http://schemas.microsoft.com/ws/2008/06/identity/claims/role"
)

var (
Expand Down Expand Up @@ -118,6 +119,11 @@ func (s SAMLProvider) givenNameAttributeNames() []string {
return []string{ObjectIDGivenName, XMLSOAPClaimsGivenName, ObjectIDName, XMLSOAPClaimsName}
}

func (s SAMLProvider) roleAttributeNames() []string {
// Added the MicrosoftClaimsRole as a fallback
return []string{MicrosoftClaimsRole}
}

func (s SAMLProvider) surnameAttributeNames() []string {
return []string{ObjectIDSurname, XMLSOAPClaimsSurname}
}
Expand Down Expand Up @@ -165,6 +171,23 @@ func (s SAMLProvider) GetSAMLUserGivenNameFromAssertion(assertion *saml.Assertio
return assertionFindString(assertion, s.givenNameAttributeNames()...)
}

// GetSAMLUserRolesFromAssertion May be empty if not present
func (s SAMLProvider) GetSAMLUserRolesFromAssertion(assertion *saml.Assertion) (roles []string) {
for _, attributeStatement := range assertion.AttributeStatements {
for _, attribute := range attributeStatement.Attributes {
for _, validName := range s.roleAttributeNames() {
if attribute.Name == validName && len(attribute.Values) > 0 {
for _, value := range attribute.Values {
roles = append(roles, value.Value)
}
}
}
}
}

return roles
}

func (s SAMLProvider) GetSAMLUserSurnameFromAssertion(assertion *saml.Assertion) (string, error) {
return assertionFindString(assertion, s.surnameAttributeNames()...)
}
Expand Down

0 comments on commit 01f1702

Please sign in to comment.