diff --git a/pkg/aclmanager/aclmanager.go b/pkg/aclmanager/aclmanager.go index 93b536a..344a9ad 100644 --- a/pkg/aclmanager/aclmanager.go +++ b/pkg/aclmanager/aclmanager.go @@ -232,45 +232,32 @@ func hashString(s string) string { return fmt.Sprintf("%x", h.Sum(nil)) } -// SyncAcls connects to the primary node and syncs the ACLs to the current node -func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]string, []string, error) { - slog.Debug("Entering SyncAcls") - defer slog.Debug("Exiting SyncAcls") - - if primary == nil { - err := fmt.Errorf("no primary found") - slog.Error("No primary found", "error", err) - return nil, nil, err - } +// listAndMapAcls is an auxiliary function to list ACLs and create a map of username to hash and ACL string +func listAndMapAcls(ctx context.Context, client *redis.Client) (map[string]string, map[string]string, error) { + slog.Debug("Entering listAndMapAcls") + defer slog.Debug("Exiting listAndMapAcls") - const batchSize = 100 - - // Map to store username to hash for source ACLs - sourceAclMap := make(map[string]string) - // Map to store username to ACL string for source ACLs (only for users needing updates) - sourceAclStrings := make(map[string]string) - - // Get source ACLs - sourceResult, err := primary.RedisClient.Do(ctx, "ACL", "LIST").Result() + result, err := client.Do(ctx, "ACL", "LIST").Result() if err != nil { - slog.Error("Failed to list source ACLs", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error listing source ACLs: %w", err) + slog.Error("Failed to list ACLs", "error", err) + return nil, nil, fmt.Errorf("listAndMapAcls: error listing ACLs: %w", err) } - sourceAclList, ok := sourceResult.([]interface{}) + aclList, ok := result.([]interface{}) if !ok { - err := fmt.Errorf("unexpected result format: %T", sourceResult) - slog.Error("Unexpected result format", "result", sourceResult) - return nil, nil, fmt.Errorf("listAcls: %w", err) + err := fmt.Errorf("unexpected result format: %T", result) + slog.Error("Unexpected result format", "result", result) + return nil, nil, fmt.Errorf("listAndMapAcls: %w", err) } - // Process source ACLs - for _, acl := range sourceAclList { + aclHashMap := make(map[string]string) + aclStrMap := make(map[string]string) + for _, acl := range aclList { aclStr, ok := acl.(string) if !ok { err := fmt.Errorf("unexpected type for ACL: %T", acl) slog.Error("Unexpected type for ACL", "acl", acl) - return nil, nil, fmt.Errorf("listAcls: %w", err) + return nil, nil, fmt.Errorf("listAndMapAcls: %w", err) } fields := strings.Fields(aclStr) if len(fields) < 2 { @@ -279,44 +266,37 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin } username := fields[1] hash := hashString(aclStr) - sourceAclMap[username] = hash - // Store the ACL string for potential updates - sourceAclStrings[username] = aclStr + aclHashMap[username] = hash + aclStrMap[username] = aclStr } - // Map to store username to hash for destination ACLs - destinationAclMap := make(map[string]string) + slog.Info("Listed and mapped ACLs", "count", len(aclHashMap)) + return aclHashMap, aclStrMap, nil +} - // Get destination ACLs - destinationResult, err := a.RedisClient.Do(ctx, "ACL", "LIST").Result() - if err != nil { - slog.Error("Failed to list destination ACLs", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error listing destination ACLs: %w", err) +// SyncAcls connects to the primary node and syncs the ACLs to the current node +func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]string, []string, error) { + slog.Debug("Entering SyncAcls") + defer slog.Debug("Exiting SyncAcls") + + if primary == nil { + err := fmt.Errorf("no primary found") + slog.Error("No primary found", "error", err) + return nil, nil, err } - destinationAclList, ok := destinationResult.([]interface{}) - if !ok { - err := fmt.Errorf("unexpected result format: %T", destinationResult) - slog.Error("Unexpected result format", "result", destinationResult) - return nil, nil, fmt.Errorf("listAcls: %w", err) + const batchSize = 100 + + // Get source ACLs + sourceAclHashMap, sourceAclStrMap, err := listAndMapAcls(ctx, primary.RedisClient) + if err != nil { + return nil, nil, fmt.Errorf("SyncAcls: error listing source ACLs: %w", err) } - // Process destination ACLs - for _, acl := range destinationAclList { - aclStr, ok := acl.(string) - if !ok { - err := fmt.Errorf("unexpected type for ACL: %T", acl) - slog.Error("Unexpected type for ACL", "acl", acl) - return nil, nil, fmt.Errorf("listAcls: %w", err) - } - fields := strings.Fields(aclStr) - if len(fields) < 2 { - slog.Warn("Invalid ACL format", "acl", aclStr) - continue - } - username := fields[1] - hash := hashString(aclStr) - destinationAclMap[username] = hash + // Get destination ACLs + destinationAclHashMap, _, err := listAndMapAcls(ctx, a.RedisClient) + if err != nil { + return nil, nil, fmt.Errorf("SyncAcls: error listing destination ACLs: %w", err) } var updated, deleted []string @@ -326,8 +306,8 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin pipe := a.RedisClient.Pipeline() // Delete ACLs that are not in the source - for username := range destinationAclMap { - if _, found := sourceAclMap[username]; !found && username != "default" { + for username := range destinationAclHashMap { + if _, found := sourceAclHashMap[username]; !found && username != "default" { slog.Debug("Deleting ACL", "username", username) cmd := pipe.Do(ctx, "ACL", "DELUSER", username) cmds = append(cmds, cmd) @@ -346,39 +326,36 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin } // Add or update ACLs from the source - for username, sourceHash := range sourceAclMap { - destHash, found := destinationAclMap[username] - if found && destHash == sourceHash { - // ACL is already up-to-date - continue - } - - aclStr := sourceAclStrings[username] - if aclStr == "" { - slog.Error("ACL string not found for user", "username", username) - continue - } + for username, sourceHash := range sourceAclHashMap { + destHash, found := destinationAclHashMap[username] + if !found || destHash != sourceHash { + aclStr := sourceAclStrMap[username] + if aclStr == "" { + slog.Error("ACL string not found for user", "username", username) + continue + } - args := []interface{}{"ACL", "SETUSER"} - fields := strings.Fields(aclStr) - // Skip the "user" keyword - for _, field := range fields[1:] { - args = append(args, field) - } + args := []interface{}{"ACL", "SETUSER"} + fields := strings.Fields(aclStr) + // Skip the "user" keyword + for _, field := range fields[1:] { + args = append(args, field) + } - cmd := pipe.Do(ctx, args...) - cmds = append(cmds, cmd) - updated = append(updated, username) + cmd := pipe.Do(ctx, args...) + cmds = append(cmds, cmd) + updated = append(updated, username) - if len(cmds) >= batchSize { - // Execute pipeline - if _, err = pipe.Exec(ctx); err != nil { - slog.Error("Failed to execute pipeline", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err) + if len(cmds) >= batchSize { + // Execute pipeline + if _, err = pipe.Exec(ctx); err != nil { + slog.Error("Failed to execute pipeline", "error", err) + return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err) + } + // Reset pipeline and cmds + pipe = a.RedisClient.Pipeline() + cmds = cmds[:0] } - // Reset pipeline and cmds - pipe = a.RedisClient.Pipeline() - cmds = cmds[:0] } } diff --git a/pkg/aclmanager/aclmanager_test.go b/pkg/aclmanager/aclmanager_test.go index 6883768..c436126 100644 --- a/pkg/aclmanager/aclmanager_test.go +++ b/pkg/aclmanager/aclmanager_test.go @@ -234,6 +234,100 @@ func TestListAcls(t *testing.T) { } } +func TestListAndMapAcls(t *testing.T) { + t.Parallel() + tests := []struct { + name string + mockResp interface{} + expectedHashMap map[string]string + expectedStrMap map[string]string + wantErr bool + expectedErrMsg string + }{ + { + name: "valid ACL list", + mockResp: []interface{}{ + "user default on nopass ~* &* +@all", + "user alice on >password ~keys:* -@all +get +set +del", + }, + expectedHashMap: map[string]string{ + "default": hashString("user default on nopass ~* &* +@all"), + "alice": hashString("user alice on >password ~keys:* -@all +get +set +del"), + }, + expectedStrMap: map[string]string{ + "default": "user default on nopass ~* &* +@all", + "alice": "user alice on >password ~keys:* -@all +get +set +del", + }, + wantErr: false, + }, + { + name: "empty ACL list", + mockResp: []interface{}{}, + expectedHashMap: map[string]string{}, + expectedStrMap: map[string]string{}, + wantErr: false, + }, + { + name: "error from Redis client", + mockResp: nil, + wantErr: true, + expectedErrMsg: "error listing ACLs", + }, + { + name: "invalid ACL format", + mockResp: []interface{}{ + "invalid_acl", + "user alice on >password ~keys:* -@all +get +set +del", + }, + expectedHashMap: map[string]string{ + "alice": hashString("user alice on >password ~keys:* -@all +get +set +del"), + }, + expectedStrMap: map[string]string{ + "alice": "user alice on >password ~keys:* -@all +get +set +del", + }, + wantErr: false, + }, + { + name: "result is not []interface{}", + mockResp: "invalid_type", + wantErr: true, + expectedErrMsg: "unexpected result format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redisClient, mock := redismock.NewClientMock() + + if tt.wantErr && tt.mockResp == nil { + mock.ExpectDo("ACL", "LIST").SetErr(fmt.Errorf("error")) + } else { + mock.ExpectDo("ACL", "LIST").SetVal(tt.mockResp) + } + + aclHashMap, aclStrMap, err := listAndMapAcls(context.Background(), redisClient) + + if (err != nil) != tt.wantErr { + t.Errorf("listAndMapAcls() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedHashMap, aclHashMap) + assert.Equal(t, tt.expectedStrMap, aclStrMap) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + func TestSyncAcls(t *testing.T) { t.Parallel() tests := []struct {