Skip to content

Commit

Permalink
feat: Apply only the highest priority role during role provision
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Dec 27, 2024
1 parent 81d2850 commit 2d8e918
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
15 changes: 11 additions & 4 deletions cmd/api/src/api/v2/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"net/url"
"path"
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -232,11 +233,12 @@ func (s ManagementResource) SSOCallbackHandler(response http.ResponseWriter, req
}
}

func (s ManagementResource) sanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.AutoProvision, maybeBHRoles []string) (validRoles model.Roles, err error) {
func (s ManagementResource) sanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.AutoProvision, maybeBHRoles []string) (roles model.Roles, err error) {
if dbRoles, err := s.db.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil {
return nil, err
} else {
if autoProvisionConfig.RoleProvision && len(maybeBHRoles) > 0 {
var validRoles model.Roles
// Make quick lookup by role slug -> lower cased, dashes for spaces, and prefixed by `bh` e.g. bh-power-user
dbRolesBySlug := make(map[string]*model.Role)
validRolesSeen := cardinality.NewBitmap32() // Ensure no dupes
Expand All @@ -250,18 +252,23 @@ func (s ManagementResource) sanitizeAndGetRoles(ctx context.Context, autoProvisi
validRolesSeen.Add(uint32(dbRole.ID))
}
}
// Sort by priority to find the "highest" role
if len(validRoles) > 0 {
sort.SliceStable(validRoles, func(i, j int) bool { return validRoles[i].Priority < validRoles[j].Priority })
roles = model.Roles{validRoles[0]}
}
}

// Fallback to default role if none found or role provision is disabled
if len(validRoles) == 0 {
if len(roles) == 0 {
for _, role := range dbRoles {
if role.ID == autoProvisionConfig.DefaultRole {
validRoles = append(validRoles, role)
roles = append(roles, role)
break
}
}
}

return validRoles, nil
return roles, nil
}
}
27 changes: 27 additions & 0 deletions cmd/api/src/database/migration/migrations/v6.4.0.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- Copyright 2024 Specter Ops, Inc.
--
-- Licensed under the Apache License, Version 2.0
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- SPDX-License-Identifier: Apache-2.0

-- Add priority column to roles table
ALTER TABLE ONLY roles
ADD COLUMN IF NOT EXISTS priority INTEGER;

-- Set priorities for roles
UPDATE roles SET priority = 1 WHERE id = (SELECT id FROM roles WHERE roles.name = 'Administrator');
UPDATE roles SET priority = 2 WHERE id = (SELECT id FROM roles WHERE roles.name = 'Power User');
UPDATE roles SET priority = 3 WHERE id = (SELECT id FROM roles WHERE roles.name = 'User');
UPDATE roles SET priority = 4 WHERE id = (SELECT id FROM roles WHERE roles.name = 'Read-Only');
UPDATE roles SET priority = 5 WHERE id = (SELECT id FROM roles WHERE roles.name = 'Upload-Only');

2 changes: 2 additions & 0 deletions cmd/api/src/model/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ type Role struct {
Name string `json:"name"`
Description string `json:"description"`
Permissions Permissions `json:"permissions" gorm:"many2many:roles_permissions"`
Priority int `json:"priority"`

Serial
}
Expand All @@ -302,6 +303,7 @@ func (s Roles) IsSortable(column string) bool {
case "name",
"description",
"id",
"priority",
"created_at",
"updated_at",
"deleted_at":
Expand Down
1 change: 0 additions & 1 deletion cmd/api/src/model/samlprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ func (s SAMLProvider) GetSAMLUserRolesFromAssertion(assertion *saml.Assertion) (
for _, attribute := range attributeStatement.Attributes {
for _, validName := range s.roleAttributeNames() {
if attribute.Name == validName && len(attribute.Values) > 0 {
// Try to find an explicit XMLType of xs:string
for _, value := range attribute.Values {
roles = append(roles, value.Value)
}
Expand Down

0 comments on commit 2d8e918

Please sign in to comment.