Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Raft State Management for Load Balancers #641

Merged
merged 9 commits into from
Dec 21, 2024
1 change: 1 addition & 0 deletions gatewayd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,4 @@ raft:
peers: []
# - id: node2
# address: 127.0.0.1:2223
# grpcAddress: 127.0.0.1:50052
22 changes: 13 additions & 9 deletions network/consistenthash.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net"
"strconv"
"sync"

gerr "github.com/gatewayd-io/gatewayd/errors"
Expand Down Expand Up @@ -69,10 +70,12 @@ func (ch *ConsistentHash) NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDEr
}

// Create and apply the command through Raft
cmd := raft.ConsistentHashCommand{
Type: raft.CommandAddConsistentHashEntry,
Hash: hash,
BlockName: proxy.GetBlockName(),
cmd := raft.Command{
Type: raft.CommandAddConsistentHashEntry,
Payload: raft.ConsistentHashPayload{
Hash: hash,
BlockName: proxy.GetBlockName(),
},
}

cmdBytes, marshalErr := json.Marshal(cmd)
Expand All @@ -81,17 +84,18 @@ func (ch *ConsistentHash) NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDEr
}

// Apply the command through Raft
if err := ch.server.RaftNode.Apply(cmdBytes, raft.LeaderElectionTimeout); err != nil {
if err := ch.server.RaftNode.Apply(cmdBytes, raft.ApplyTimeout); err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}

return proxy, nil
}

// hashKey hashes a given key using the MurmurHash3 algorithm. It is used to generate consistent hash values
// for IP addresses or connection strings.
func hashKey(key string) uint64 {
return murmur3.Sum64([]byte(key))
// hashKey hashes a given key using the MurmurHash3 algorithm and returns it as a string. It is used to generate
// consistent hash values for IP addresses or connection strings.
func hashKey(key string) string {
hash := murmur3.Sum64([]byte(key))
return strconv.FormatUint(hash, 10)
}

// extractIPFromConn extracts the IP address from the connection's remote address. It splits the address
Expand Down
10 changes: 6 additions & 4 deletions network/consistenthash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ func TestConsistentHashNextProxyUseSourceIpExists(t *testing.T) {
// Instead of setting hashMap directly, set up the FSM
hash := hashKey("192.168.1.1" + server.GroupName)
// Create and apply the command through Raft
cmd := raft.ConsistentHashCommand{
Type: raft.CommandAddConsistentHashEntry,
Hash: hash,
BlockName: proxies[2].GetBlockName(),
cmd := raft.Command{
Type: raft.CommandAddConsistentHashEntry,
Payload: raft.ConsistentHashPayload{
Hash: hash,
BlockName: proxies[2].GetBlockName(),
},
}

cmdBytes, marshalErr := json.Marshal(cmd)
Expand Down
2 changes: 2 additions & 0 deletions network/network_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ func setupProxy(
ClientConfig: clientConfig,
Logger: logger,
PluginTimeout: config.DefaultPluginTimeout,
GroupName: "test-group",
BlockName: clientIP + ":" + clientPort,
},
)

Expand Down
39 changes: 35 additions & 4 deletions network/roundrobin.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
package network

import (
"encoding/json"
"math"
"sync/atomic"
"sync"

gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/gatewayd-io/gatewayd/raft"
)

type RoundRobin struct {
proxies []IProxy
next atomic.Uint32
server *Server
mu sync.Mutex
}

// NewRoundRobin creates a new RoundRobin load balancer.
func NewRoundRobin(server *Server) *RoundRobin {
return &RoundRobin{proxies: server.Proxies}
return &RoundRobin{proxies: server.Proxies, server: server}
}

// NextProxy returns the next proxy in the round-robin sequence.
func (r *RoundRobin) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) {
if len(r.proxies) > math.MaxUint32 {
// This should never happen, but if it does, we fall back to the first proxy.
Expand All @@ -24,6 +29,32 @@ func (r *RoundRobin) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) {
return nil, gerr.ErrNoProxiesAvailable
}

nextIndex := r.next.Add(1)
r.mu.Lock()
defer r.mu.Unlock()

// Get current index from Raft FSM
currentIndex := r.server.RaftNode.Fsm.GetRoundRobinNext(r.server.GroupName)
nextIndex := currentIndex + 1

// Create Raft command
cmd := raft.Command{
Type: raft.CommandAddRoundRobinNext,
Payload: raft.RoundRobinPayload{
NextIndex: nextIndex,
GroupName: r.server.GroupName,
},
}

// Convert command to JSON
data, err := json.Marshal(cmd)
if err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}

// Apply through Raft
if err := r.server.RaftNode.Apply(data, raft.ApplyTimeout); err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}

return r.proxies[nextIndex%uint32(len(r.proxies))], nil //nolint:gosec
}
59 changes: 54 additions & 5 deletions network/roundrobin_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package network

import (
"encoding/json"
"math"
"sync"
"testing"

"github.com/gatewayd-io/gatewayd/raft"
"github.com/gatewayd-io/gatewayd/testhelpers"
"github.com/stretchr/testify/require"
)

// TestNewRoundRobin tests the NewRoundRobin function to ensure that it correctly initializes
Expand All @@ -25,12 +30,21 @@ func TestNewRoundRobin(t *testing.T) {
// TestRoundRobin_NextProxy tests the NextProxy method of the round-robin load balancer to ensure
// that it returns proxies in the expected order.
func TestRoundRobin_NextProxy(t *testing.T) {
raftHelper, err := testhelpers.NewTestRaftNode(t)
if err != nil {
t.Fatalf("Failed to create test raft node: %v", err)
}
defer func() {
if err := raftHelper.Cleanup(); err != nil {
t.Errorf("Failed to cleanup raft: %v", err)
}
}()
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{Proxies: proxies}
server := &Server{Proxies: proxies, RaftNode: raftHelper.Node, GroupName: "test-group"}
roundRobin := NewRoundRobin(server)

expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"}
Expand Down Expand Up @@ -58,7 +72,18 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) {
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{Proxies: proxies}

raftHelper, err := testhelpers.NewTestRaftNode(t)
if err != nil {
t.Fatalf("Failed to create test raft node: %v", err)
}
defer func() {
if err := raftHelper.Cleanup(); err != nil {
t.Errorf("Failed to cleanup raft: %v", err)
}
}()
server := &Server{Proxies: proxies, RaftNode: raftHelper.Node, GroupName: "test-group"}
server.initializeProxies()
roundRobin := NewRoundRobin(server)

var waitGroup sync.WaitGroup
Expand All @@ -73,7 +98,7 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) {
}

waitGroup.Wait()
nextIndex := roundRobin.next.Load()
nextIndex := server.RaftNode.Fsm.GetRoundRobinNext(server.GroupName)
if nextIndex != uint32(numGoroutines) {
t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex)
}
Expand All @@ -84,18 +109,42 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) {
// uint32 value and ensures that the proxy selection wraps around as expected when the
// counter overflows.
func TestNextProxyOverflow(t *testing.T) {
raftHelper, err := testhelpers.NewTestRaftNode(t)
if err != nil {
t.Fatalf("Failed to create test raft node: %v", err)
}
defer func() {
if err := raftHelper.Cleanup(); err != nil {
t.Errorf("Failed to cleanup raft: %v", err)
}
}()
// Create a server with a few mock proxies
server := &Server{
Proxies: []IProxy{
&MockProxy{},
&MockProxy{},
&MockProxy{},
},
GroupName: "test-group",
RaftNode: raftHelper.Node,
}
server.initializeProxies()

roundRobin := NewRoundRobin(server)

// Set the next value to near the max uint32 value to force an overflow
roundRobin.next.Store(math.MaxUint32 - 1)
cmd := raft.Command{
Type: raft.CommandAddRoundRobinNext,
Payload: raft.RoundRobinPayload{
NextIndex: math.MaxUint32 - 1,
GroupName: server.GroupName,
},
}
// Convert command to JSON
data, err := json.Marshal(cmd)
require.NoError(t, err)

require.NoError(t, raftHelper.Node.Apply(data, raft.ApplyTimeout))

// Call NextProxy multiple times to trigger the overflow
for range 4 {
Expand All @@ -110,7 +159,7 @@ func TestNextProxyOverflow(t *testing.T) {

// After overflow, next value should wrap around
expectedNextValue := uint32(2) // MaxUint32 - 1 incremented 4 times wraps around (mod 2^32) to 2
actualNextValue := roundRobin.next.Load()
actualNextValue := server.RaftNode.Fsm.GetRoundRobinNext(server.GroupName)
if actualNextValue != expectedNextValue {
t.Fatalf("Expected next value to be %v, got %v", expectedNextValue, actualNextValue)
}
Expand Down
31 changes: 26 additions & 5 deletions network/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ func TestRunServer(t *testing.T) {
proxy1 := setupProxy(ctx, t, postgresHostIP1, postgresMappedPort1.Port(), logger, pluginRegistry)
proxy2 := setupProxy(ctx, t, postgresHostIP2, postgresMappedPort2.Port(), logger, pluginRegistry)

raftHelper, err := testhelpers.NewTestRaftNode(t)
if err != nil {
t.Fatalf("Failed to create test raft node: %v", err)
}
defer func() {
if err := raftHelper.Cleanup(); err != nil {
t.Errorf("Failed to cleanup raft: %v", err)
}
}()

// Create a server.
server := NewServer(
ctx,
Expand All @@ -94,6 +104,8 @@ func TestRunServer(t *testing.T) {
PluginTimeout: config.DefaultPluginTimeout,
HandshakeTimeout: config.DefaultHandshakeTimeout,
LoadbalancerStrategyName: config.RoundRobinStrategy,
GroupName: "test-group",
RaftNode: raftHelper.Node,
},
)
assert.NotNil(t, server)
Expand All @@ -114,9 +126,10 @@ func TestRunServer(t *testing.T) {
}
}(t, server)

var proxyStateMutex sync.Mutex

testProxy := func(
t *testing.T,
proxy *Proxy,
waitGroup *sync.WaitGroup,
) {
t.Helper()
Expand Down Expand Up @@ -163,8 +176,16 @@ func TestRunServer(t *testing.T) {
// AuthenticationOk.
assert.Equal(t, uint8(0x52), data[0])

assert.Equal(t, 2, proxy.AvailableConnections.Size())
assert.Equal(t, 1, proxy.busyConnections.Size())
// Lock the mutex before checking the proxy states
proxyStateMutex.Lock()
defer proxyStateMutex.Unlock()

// Check that one of the proxies has the expected state.
proxyInExpectedState := (proxy1.AvailableConnections.Size() == 2 && proxy1.busyConnections.Size() == 1) ||
(proxy2.AvailableConnections.Size() == 2 && proxy2.busyConnections.Size() == 1)
if !proxyInExpectedState {
t.Errorf("Neither proxy is in the expected state")
}

// Terminate the connection.
sent, err = client.Send(CreatePgTerminatePacket())
Expand All @@ -180,8 +201,8 @@ func TestRunServer(t *testing.T) {
// Test both proxies.
// Based on the default Loadbalancer strategy (RoundRobin), the first client request will be sent to proxy2,
// followed by proxy1 for the next request.
go testProxy(t, proxy2, &waitGroup)
go testProxy(t, proxy1, &waitGroup)
go testProxy(t, &waitGroup)
go testProxy(t, &waitGroup)

// Wait for all goroutines.
waitGroup.Wait()
Expand Down
Loading
Loading