Skip to content

Commit

Permalink
fix: tests + formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Jan 2, 2025
1 parent 07bf572 commit 494eb5e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cmd/api/src/api/v2/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ func getOIDCClaims(reqCtx context.Context, provider *oidc.Provider, ssoProvider
}

func jitOIDCUserCreation(ctx context.Context, ssoProvider model.SSOProvider, claims oidcClaims, u jitUserCreator) error {
if roles, err := sanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, claims.Roles, u); err != nil {
return fmt.Errorf("sanitizeAndGetRoles: %v", err)
if roles, err := SanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, claims.Roles, u); err != nil {
return fmt.Errorf("sanitize roles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles %v", roles.Names())
} else if _, err := u.LookupUser(ctx, claims.Email); err != nil && !errors.Is(err, database.ErrNotFound) {
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/api/v2/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re
}

func jitSAMLUserCreation(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserCreator) error {
if roles, err := sanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, ssoProvider.SAMLProvider.GetSAMLUserRolesFromAssertion(assertion), u); err != nil {
return fmt.Errorf("sanitizeAndGetRoles: %v", err)
if roles, err := SanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, ssoProvider.SAMLProvider.GetSAMLUserRolesFromAssertion(assertion), u); err != nil {
return fmt.Errorf("sanitize roles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles %v", roles.Names())
} else if _, err := u.LookupUser(ctx, principalName); err != nil && !errors.Is(err, database.ErrNotFound) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/src/api/v2/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func (s ManagementResource) SSOCallbackHandler(response http.ResponseWriter, req
}
}

func sanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.SSOProviderAutoProvisionConfig, maybeBHRoles []string, r getAllRoler) (model.Roles, error) {
func SanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.SSOProviderAutoProvisionConfig, maybeBHRoles []string, r getAllRoler) (model.Roles, error) {
if dbRoles, err := r.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil {
return nil, err
} else {
Expand Down
49 changes: 49 additions & 0 deletions cmd/api/src/api/v2/auth/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package auth_test

import (
"context"
"net/http"
"net/url"
"testing"
Expand All @@ -31,6 +32,7 @@ import (
"github.com/specterops/bloodhound/src/database/types/null"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/utils/test"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

Expand Down Expand Up @@ -247,3 +249,50 @@ func TestManagementResource_DeleteOIDCProvider(t *testing.T) {
ResponseStatusCode(http.StatusNotFound)
})
}

func TestManagementResource_SanitizeAndGetRoles(t *testing.T) {
var (
mockCtrl = gomock.NewController(t)
_, mockDB = apitest.NewAuthManagementResource(mockCtrl)
testCtx = context.Background()

dbRoles = model.Roles{
{Name: "God Role", Serial: model.Serial{ID: 1}},
{Name: "Default Role", Serial: model.Serial{ID: 2}},
{Name: "Valid Role", Serial: model.Serial{ID: 3}},
}
roleProvisionEnabledConfig = model.SSOProviderAutoProvisionConfig{RoleProvision: true, DefaultRoleId: 2, Enabled: true}
roleProvisionDisabledConfig = model.SSOProviderAutoProvisionConfig{RoleProvision: false, DefaultRoleId: 2, Enabled: true}
)
t.Run("role provision enabled - return valid role", func(t *testing.T) {
mockDB.EXPECT().GetAllRoles(gomock.Any(), "", model.SQLFilter{}).Return(dbRoles, nil)
roles, err := auth.SanitizeAndGetRoles(testCtx, roleProvisionEnabledConfig, []string{"ignored", "bh-valid-role"}, mockDB)
require.Nil(t, err)
require.Len(t, roles, 1)
require.Equal(t, roles[0].ID, dbRoles[2].ID)
})

t.Run("role provision enabled - return default role when multiple valid roles", func(t *testing.T) {
mockDB.EXPECT().GetAllRoles(gomock.Any(), "", model.SQLFilter{}).Return(dbRoles, nil)
roles, err := auth.SanitizeAndGetRoles(testCtx, roleProvisionEnabledConfig, []string{"bh-valid-role", "ignored", "bh-god-role"}, mockDB)
require.Nil(t, err)
require.Len(t, roles, 1)
require.Equal(t, roles[0].ID, roleProvisionEnabledConfig.DefaultRoleId)
})

t.Run("role provision enabled - return default role when no valid roles", func(t *testing.T) {
mockDB.EXPECT().GetAllRoles(gomock.Any(), "", model.SQLFilter{}).Return(dbRoles, nil)
roles, err := auth.SanitizeAndGetRoles(testCtx, roleProvisionEnabledConfig, []string{"bh-invalid-role", "ignored"}, mockDB)
require.Nil(t, err)
require.Len(t, roles, 1)
require.Equal(t, roles[0].ID, roleProvisionEnabledConfig.DefaultRoleId)
})

t.Run("role provision disabled - return default role", func(t *testing.T) {
mockDB.EXPECT().GetAllRoles(gomock.Any(), "", model.SQLFilter{}).Return(dbRoles, nil)
roles, err := auth.SanitizeAndGetRoles(testCtx, roleProvisionDisabledConfig, []string{"bh-valid-role", "ignored", "bh-god-role"}, mockDB)
require.Nil(t, err)
require.Len(t, roles, 1)
require.Equal(t, roles[0].ID, roleProvisionEnabledConfig.DefaultRoleId)
})
}

0 comments on commit 494eb5e

Please sign in to comment.