Skip to content

Commit

Permalink
tidying a bit the code to simplify testing syncACLS
Browse files Browse the repository at this point in the history
  • Loading branch information
ncode committed Oct 18, 2024
1 parent 9d5253d commit 4d37175
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 90 deletions.
157 changes: 67 additions & 90 deletions pkg/aclmanager/aclmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,45 +232,32 @@ func hashString(s string) string {
return fmt.Sprintf("%x", h.Sum(nil))
}

// SyncAcls connects to the primary node and syncs the ACLs to the current node
func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]string, []string, error) {
slog.Debug("Entering SyncAcls")
defer slog.Debug("Exiting SyncAcls")

if primary == nil {
err := fmt.Errorf("no primary found")
slog.Error("No primary found", "error", err)
return nil, nil, err
}
// listAndMapAcls is an auxiliary function to list ACLs and create a map of username to hash and ACL string
func listAndMapAcls(ctx context.Context, client *redis.Client) (map[string]string, map[string]string, error) {
slog.Debug("Entering listAndMapAcls")
defer slog.Debug("Exiting listAndMapAcls")

const batchSize = 100

// Map to store username to hash for source ACLs
sourceAclMap := make(map[string]string)
// Map to store username to ACL string for source ACLs (only for users needing updates)
sourceAclStrings := make(map[string]string)

// Get source ACLs
sourceResult, err := primary.RedisClient.Do(ctx, "ACL", "LIST").Result()
result, err := client.Do(ctx, "ACL", "LIST").Result()
if err != nil {
slog.Error("Failed to list source ACLs", "error", err)
return nil, nil, fmt.Errorf("SyncAcls: error listing source ACLs: %w", err)
slog.Error("Failed to list ACLs", "error", err)
return nil, nil, fmt.Errorf("listAndMapAcls: error listing ACLs: %w", err)
}

sourceAclList, ok := sourceResult.([]interface{})
aclList, ok := result.([]interface{})
if !ok {
err := fmt.Errorf("unexpected result format: %T", sourceResult)
slog.Error("Unexpected result format", "result", sourceResult)
return nil, nil, fmt.Errorf("listAcls: %w", err)
err := fmt.Errorf("unexpected result format: %T", result)
slog.Error("Unexpected result format", "result", result)
return nil, nil, fmt.Errorf("listAndMapAcls: %w", err)
}

// Process source ACLs
for _, acl := range sourceAclList {
aclHashMap := make(map[string]string)
aclStrMap := make(map[string]string)
for _, acl := range aclList {
aclStr, ok := acl.(string)
if !ok {
err := fmt.Errorf("unexpected type for ACL: %T", acl)
slog.Error("Unexpected type for ACL", "acl", acl)
return nil, nil, fmt.Errorf("listAcls: %w", err)
return nil, nil, fmt.Errorf("listAndMapAcls: %w", err)
}
fields := strings.Fields(aclStr)
if len(fields) < 2 {
Expand All @@ -279,44 +266,37 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin
}
username := fields[1]
hash := hashString(aclStr)
sourceAclMap[username] = hash
// Store the ACL string for potential updates
sourceAclStrings[username] = aclStr
aclHashMap[username] = hash
aclStrMap[username] = aclStr
}

// Map to store username to hash for destination ACLs
destinationAclMap := make(map[string]string)
slog.Info("Listed and mapped ACLs", "count", len(aclHashMap))
return aclHashMap, aclStrMap, nil
}

// Get destination ACLs
destinationResult, err := a.RedisClient.Do(ctx, "ACL", "LIST").Result()
if err != nil {
slog.Error("Failed to list destination ACLs", "error", err)
return nil, nil, fmt.Errorf("SyncAcls: error listing destination ACLs: %w", err)
// SyncAcls connects to the primary node and syncs the ACLs to the current node
func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]string, []string, error) {
slog.Debug("Entering SyncAcls")
defer slog.Debug("Exiting SyncAcls")

if primary == nil {
err := fmt.Errorf("no primary found")
slog.Error("No primary found", "error", err)
return nil, nil, err
}

destinationAclList, ok := destinationResult.([]interface{})
if !ok {
err := fmt.Errorf("unexpected result format: %T", destinationResult)
slog.Error("Unexpected result format", "result", destinationResult)
return nil, nil, fmt.Errorf("listAcls: %w", err)
const batchSize = 100

// Get source ACLs
sourceAclHashMap, sourceAclStrMap, err := listAndMapAcls(ctx, primary.RedisClient)
if err != nil {
return nil, nil, fmt.Errorf("SyncAcls: error listing source ACLs: %w", err)
}

// Process destination ACLs
for _, acl := range destinationAclList {
aclStr, ok := acl.(string)
if !ok {
err := fmt.Errorf("unexpected type for ACL: %T", acl)
slog.Error("Unexpected type for ACL", "acl", acl)
return nil, nil, fmt.Errorf("listAcls: %w", err)
}
fields := strings.Fields(aclStr)
if len(fields) < 2 {
slog.Warn("Invalid ACL format", "acl", aclStr)
continue
}
username := fields[1]
hash := hashString(aclStr)
destinationAclMap[username] = hash
// Get destination ACLs
destinationAclHashMap, _, err := listAndMapAcls(ctx, a.RedisClient)
if err != nil {
return nil, nil, fmt.Errorf("SyncAcls: error listing destination ACLs: %w", err)
}

var updated, deleted []string
Expand All @@ -326,8 +306,8 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin
pipe := a.RedisClient.Pipeline()

// Delete ACLs that are not in the source
for username := range destinationAclMap {
if _, found := sourceAclMap[username]; !found && username != "default" {
for username := range destinationAclHashMap {
if _, found := sourceAclHashMap[username]; !found && username != "default" {
slog.Debug("Deleting ACL", "username", username)
cmd := pipe.Do(ctx, "ACL", "DELUSER", username)
cmds = append(cmds, cmd)
Expand All @@ -346,39 +326,36 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin
}

// Add or update ACLs from the source
for username, sourceHash := range sourceAclMap {
destHash, found := destinationAclMap[username]
if found && destHash == sourceHash {
// ACL is already up-to-date
continue
}

aclStr := sourceAclStrings[username]
if aclStr == "" {
slog.Error("ACL string not found for user", "username", username)
continue
}
for username, sourceHash := range sourceAclHashMap {
destHash, found := destinationAclHashMap[username]
if !found || destHash != sourceHash {
aclStr := sourceAclStrMap[username]
if aclStr == "" {
slog.Error("ACL string not found for user", "username", username)
continue

Check warning on line 335 in pkg/aclmanager/aclmanager.go

View check run for this annotation

Codecov / codecov/patch

pkg/aclmanager/aclmanager.go#L334-L335

Added lines #L334 - L335 were not covered by tests
}

args := []interface{}{"ACL", "SETUSER"}
fields := strings.Fields(aclStr)
// Skip the "user" keyword
for _, field := range fields[1:] {
args = append(args, field)
}
args := []interface{}{"ACL", "SETUSER"}
fields := strings.Fields(aclStr)
// Skip the "user" keyword
for _, field := range fields[1:] {
args = append(args, field)
}

cmd := pipe.Do(ctx, args...)
cmds = append(cmds, cmd)
updated = append(updated, username)
cmd := pipe.Do(ctx, args...)
cmds = append(cmds, cmd)
updated = append(updated, username)

if len(cmds) >= batchSize {
// Execute pipeline
if _, err = pipe.Exec(ctx); err != nil {
slog.Error("Failed to execute pipeline", "error", err)
return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err)
if len(cmds) >= batchSize {
// Execute pipeline
if _, err = pipe.Exec(ctx); err != nil {
slog.Error("Failed to execute pipeline", "error", err)
return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err)
}

Check warning on line 354 in pkg/aclmanager/aclmanager.go

View check run for this annotation

Codecov / codecov/patch

pkg/aclmanager/aclmanager.go#L350-L354

Added lines #L350 - L354 were not covered by tests
// Reset pipeline and cmds
pipe = a.RedisClient.Pipeline()
cmds = cmds[:0]

Check warning on line 357 in pkg/aclmanager/aclmanager.go

View check run for this annotation

Codecov / codecov/patch

pkg/aclmanager/aclmanager.go#L356-L357

Added lines #L356 - L357 were not covered by tests
}
// Reset pipeline and cmds
pipe = a.RedisClient.Pipeline()
cmds = cmds[:0]
}
}

Expand Down
94 changes: 94 additions & 0 deletions pkg/aclmanager/aclmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,100 @@ func TestListAcls(t *testing.T) {
}
}

func TestListAndMapAcls(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mockResp interface{}
expectedHashMap map[string]string
expectedStrMap map[string]string
wantErr bool
expectedErrMsg string
}{
{
name: "valid ACL list",
mockResp: []interface{}{
"user default on nopass ~* &* +@all",
"user alice on >password ~keys:* -@all +get +set +del",
},
expectedHashMap: map[string]string{
"default": hashString("user default on nopass ~* &* +@all"),
"alice": hashString("user alice on >password ~keys:* -@all +get +set +del"),
},
expectedStrMap: map[string]string{
"default": "user default on nopass ~* &* +@all",
"alice": "user alice on >password ~keys:* -@all +get +set +del",
},
wantErr: false,
},
{
name: "empty ACL list",
mockResp: []interface{}{},
expectedHashMap: map[string]string{},
expectedStrMap: map[string]string{},
wantErr: false,
},
{
name: "error from Redis client",
mockResp: nil,
wantErr: true,
expectedErrMsg: "error listing ACLs",
},
{
name: "invalid ACL format",
mockResp: []interface{}{
"invalid_acl",
"user alice on >password ~keys:* -@all +get +set +del",
},
expectedHashMap: map[string]string{
"alice": hashString("user alice on >password ~keys:* -@all +get +set +del"),
},
expectedStrMap: map[string]string{
"alice": "user alice on >password ~keys:* -@all +get +set +del",
},
wantErr: false,
},
{
name: "result is not []interface{}",
mockResp: "invalid_type",
wantErr: true,
expectedErrMsg: "unexpected result format",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redisClient, mock := redismock.NewClientMock()

if tt.wantErr && tt.mockResp == nil {
mock.ExpectDo("ACL", "LIST").SetErr(fmt.Errorf("error"))
} else {
mock.ExpectDo("ACL", "LIST").SetVal(tt.mockResp)
}

aclHashMap, aclStrMap, err := listAndMapAcls(context.Background(), redisClient)

if (err != nil) != tt.wantErr {
t.Errorf("listAndMapAcls() error = %v, wantErr %v", err, tt.wantErr)
return
}

if tt.wantErr {
assert.Error(t, err)
if tt.expectedErrMsg != "" {
assert.Contains(t, err.Error(), tt.expectedErrMsg)
}
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedHashMap, aclHashMap)
assert.Equal(t, tt.expectedStrMap, aclStrMap)
}

assert.NoError(t, mock.ExpectationsWereMet())
})
}
}

func TestSyncAcls(t *testing.T) {
t.Parallel()
tests := []struct {
Expand Down

0 comments on commit 4d37175

Please sign in to comment.