Skip to content

Commit

Permalink
adds a few more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ncode committed Nov 27, 2023
1 parent 8ab9e81 commit 9b035a2
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 30 deletions.
67 changes: 37 additions & 30 deletions pkg/aclmanager/aclmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,37 +109,9 @@ func (a *AclManager) SyncAcls() (err error) {
})
defer master.Close()

masterAcls, err := listAcls(master)
_, err := syncAcls(master, a.RedisClient)
if err != nil {
return fmt.Errorf("error listing master acls: %v", err)
}

currentAcls, err := listAcls(a.RedisClient)
if err != nil {
return fmt.Errorf("error listing current acls: %v", err)
}

aclsToSync := []string{}
for _, acl := range currentAcls {
if slices.Contains(masterAcls, acl) {
continue
}
err = a.RedisClient.Do(context.Background(), "ACL", "DELUSER", acl).Err()
if err != nil {
return fmt.Errorf("error deleting acl: %v", err)
}
aclsToSync = append(aclsToSync, acl)
}

for _, acl := range aclsToSync {
err = a.RedisClient.Do(context.Background(), "ACL", "SETUSER", acl).Err()
if err != nil {
return fmt.Errorf("error setting acl: %v", err)
}
}

if err != nil {
return err
return fmt.Errorf("error syncing acls: %v", err)
}
}
}
Expand Down Expand Up @@ -180,3 +152,38 @@ func listAcls(client *redis.Client) (acls []string, err error) {

return acls, err
}

// syncAcls returns a list of acls in the cluster based on the redis acl list command
func syncAcls(source *redis.Client, destination *redis.Client) (deleted []string, err error) {
aclsToSync, err := listAcls(source)
if err != nil {
return deleted, fmt.Errorf("error listing master acls: %v", err)
}

currentAcls, err := listAcls(destination)
if err != nil {
return deleted, fmt.Errorf("error listing current acls: %v", err)
}

for _, acl := range currentAcls {
acl := acl
if pos := slices.Index(aclsToSync, acl); pos != -1 {
aclsToSync = slices.Delete(aclsToSync, pos, pos+1)
continue
}
err = destination.Do(context.Background(), "ACL", "DELUSER", acl).Err()
if err != nil {
return deleted, fmt.Errorf("error deleting acl: %v", err)
}
deleted = append(deleted, acl)
}

for _, acl := range aclsToSync {
err = destination.Do(context.Background(), "ACL", "SETUSER", acl).Err()
if err != nil {
return deleted, fmt.Errorf("error setting acl: %v", err)
}
}

return deleted, err
}
73 changes: 73 additions & 0 deletions pkg/aclmanager/aclmanager_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package aclmanager

import (
"fmt"
"reflect"
"testing"

"github.com/go-redis/redismock/v9"
Expand Down Expand Up @@ -160,3 +162,74 @@ func TestListAcls(t *testing.T) {
})
}
}

func TestSyncAcls(t *testing.T) {
tests := []struct {
name string
sourceAcls []interface{}
destinationAcls []interface{}
expectedDeleted []string
expectedAdded []string
listAclsError error
redisDoError error
wantErr bool
}{
{
name: "ACLs synced with deletions",
sourceAcls: []interface{}{"acl1", "acl2"},
destinationAcls: []interface{}{"acl1", "acl3"},
expectedDeleted: []string{"acl3"},
expectedAdded: []string{"acl2"},
wantErr: false,
},
{
name: "No ACLs to delete",
sourceAcls: []interface{}{"acl1", "acl2"},
destinationAcls: []interface{}{"acl1", "acl2"},
expectedDeleted: nil,
wantErr: false,
},
{
name: "Error listing source ACLs",
listAclsError: fmt.Errorf("error listing source ACLs"),
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sourceClient, sourceMock := redismock.NewClientMock()
destinationClient, destMock := redismock.NewClientMock()

if tt.listAclsError != nil {
sourceMock.ExpectDo("ACL", "LIST").SetErr(tt.listAclsError)
} else {
sourceMock.ExpectDo("ACL", "LIST").SetVal(tt.sourceAcls)
}

if tt.listAclsError != nil {
destMock.ExpectDo("ACL", "LIST").SetErr(tt.listAclsError)
} else {
destMock.ExpectDo("ACL", "LIST").SetVal(tt.destinationAcls)
if tt.expectedDeleted != nil {
for _, acl := range tt.expectedDeleted {
destMock.ExpectDo("ACL", "DELUSER", acl).SetVal("OK")
}
}
if tt.expectedAdded != nil {
for _, acl := range tt.expectedAdded {
destMock.ExpectDo("ACL", "SETUSER", acl).SetVal("OK")
}
}
}

deleted, err := syncAcls(sourceClient, destinationClient)
if (err != nil) != tt.wantErr {
t.Errorf("syncAcls() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(deleted, tt.expectedDeleted) {
t.Errorf("syncAcls() deleted = %v, expectedDeleted %v", deleted, tt.expectedDeleted)
}
})
}
}

0 comments on commit 9b035a2

Please sign in to comment.