Add Raft State Management for Load Balancers (#641)
* feat(raft): add round-robin state management

- Implemented round-robin state in the Raft FSM to keep proxy selection consistent during leader election and task distribution.
- Updated Raft state machine to incorporate round-robin logic.
- Added tests to ensure correct round-robin behavior in various scenarios.

This change improves the efficiency and fairness of task handling within the Raft cluster.
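
For illustration, a minimal sketch of the FSM side of this state (not part of the diff below, which only covers the load-balancer callers; the map name, lock, and apply helper are assumptions — only GetRoundRobinNext, CommandAddRoundRobinNext, and the RoundRobinPayload fields appear in this commit):

// Sketch only: the real FSM lives in the raft package and may differ.
type RoundRobinPayload struct {
	NextIndex uint32
	GroupName string
}

type FSM struct {
	mu              sync.RWMutex
	roundRobinIndex map[string]uint32 // load balancer group name -> next index
}

// GetRoundRobinNext returns the replicated round-robin index for a group.
func (f *FSM) GetRoundRobinNext(groupName string) uint32 {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.roundRobinIndex[groupName]
}

// applyRoundRobinNext would run when a committed CommandAddRoundRobinNext
// log entry is applied, so every node converges on the same index.
func (f *FSM) applyRoundRobinNext(payload RoundRobinPayload) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.roundRobinIndex[payload.GroupName] = payload.NextIndex
}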

* Fix race condition in server test by adding mutex for proxy state checks

- Introduced a mutex (`proxyStateMutex`) to synchronize access to proxy state checks in the `testProxy` function.
- Modified the `testProxy` function to lock the mutex before checking the state of `AvailableConnections` and `busyConnections`.
- Updated the test logic to ensure that one of the proxies is in the expected state, preventing race conditions where the second goroutine could access the connection state prematurely.
- Removed the `proxy` parameter from `testProxy` function calls as it is no longer needed.

This change addresses a race condition that could cause the second goroutine to access connection states before they are properly synchronized, ensuring reliable test results.

* feat: Add weighted round-robin state to Raft

Add support for storing weighted round-robin load balancer state in Raft FSM
to ensure consistency across cluster nodes. Changes include:

- Add WeightedProxy and WeightedRRPayload structs for state management
- Store proxy weights in Raft FSM using weightedRRStates map
- Update WeightedRoundRobin to use Raft for weight tracking
- Add new CommandUpdateWeightedRR command type
- Remove local weight tracking in favor of distributed state

This change ensures that proxy weights remain consistent across cluster nodes
during failover and leader changes.
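
These structs are internal to the raft package and not shown in this diff; a plausible sketch, where any field beyond the names in this commit message is a guess:

// WeightedProxy holds the per-proxy weight state replicated through Raft.
// The exact fields are an assumption based on weighted round-robin selection.
type WeightedProxy struct {
	Weight        int // configured weight
	CurrentWeight int // running weight adjusted on each selection
}

// WeightedRRPayload is the single-proxy update carried by CommandUpdateWeightedRR.
type WeightedRRPayload struct {
	GroupName string
	ProxyName string
	Weight    WeightedProxy
}

In the FSM, weightedRRStates would then be a map[string]map[string]WeightedProxy, keyed by group name and then proxy (block) name.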

* perf(network): optimize weighted round-robin with batch updates

Improve performance of WeightedRoundRobin.NextProxy by reducing Raft operations:
- Replace multiple individual Raft updates with a single batch operation
- Introduce new CommandUpdateWeightedRRBatch command type
- Collect all proxy weight updates in memory before applying
- Reduce the number of Raft.Apply calls from N+1 to 1 (where N is the number of proxies)

This change significantly reduces the number of Raft consensus operations
needed for weight updates in the weighted round-robin load balancer.
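
Schematically, the batch path could look like this — a sketch assuming WeightedRoundRobin holds a server reference the way RoundRobin does; the payload shape and the applyBatch helper are illustrative, as the commit only names CommandUpdateWeightedRRBatch and the N+1-to-1 reduction:

// WeightedRRBatchPayload carries every proxy's new weight state in one
// Raft log entry instead of one entry per proxy.
type WeightedRRBatchPayload struct {
	GroupName string
	Updates   map[string]WeightedProxy // proxy name -> updated weight state
}

// applyBatch commits all weight updates collected in memory with a single
// Raft.Apply call.
func (w *WeightedRoundRobin) applyBatch(updates map[string]WeightedProxy) *gerr.GatewayDError {
	cmd := raft.Command{
		Type: raft.CommandUpdateWeightedRRBatch,
		Payload: raft.WeightedRRBatchPayload{
			GroupName: w.server.GroupName,
			Updates:   updates,
		},
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return gerr.ErrNoProxiesAvailable.Wrap(err)
	}
	if err := w.server.RaftNode.Apply(data, raft.ApplyTimeout); err != nil {
		return gerr.ErrNoProxiesAvailable.Wrap(err)
	}
	return nil
}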

* fix: improve error handling and code readability in network package

- Add JSON marshaling error handling in WeightedRoundRobin
- Simplify proxy state validation logic in server tests
- Clean up test formatting

The main changes improve error propagation and code clarity by:
- Properly handling JSON marshaling errors in NextProxy method
- Refactoring conditional logic in server tests to be more readable
- Removing unnecessary empty lines in test files

* feat: Add weighted round-robin state to FSM snapshots

Add support for persisting and restoring weighted round-robin load balancer
states in the Raft FSM snapshots. This ensures the weighted round-robin
configuration survives cluster restarts and leader changes.

Changes:
- Add weightedRRStates to FSMSnapshot struct
- Update Snapshot() to copy weighted round-robin states
- Extend Restore() and Persist() to handle weighted round-robin data
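
A minimal Persist sketch, assuming the snapshot is JSON-encoded through hashicorp/raft's SnapshotSink. Only the weightedRRStates field is named in this commit; the FSM likely stores it unexported and copies it into an encodable form, and the capitalized field names here are that assumed copy (WeightedProxy as sketched earlier):

import (
	"encoding/json"

	hraft "github.com/hashicorp/raft"
)

// FSMSnapshot carries a point-in-time copy of the replicated state,
// including the weighted round-robin weights.
type FSMSnapshot struct {
	RoundRobinIndex  map[string]uint32
	WeightedRRStates map[string]map[string]WeightedProxy
}

// Persist writes the snapshot through the sink so the state survives
// restarts; Cancel discards the partial snapshot on any error.
func (s *FSMSnapshot) Persist(sink hraft.SnapshotSink) error {
	data, err := json.Marshal(s)
	if err != nil {
		sink.Cancel()
		return err
	}
	if _, err := sink.Write(data); err != nil {
		sink.Cancel()
		return err
	}
	return sink.Close()
}

// Release is a no-op because Snapshot() copies the maps before Persist runs.
func (s *FSMSnapshot) Release() {}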

* Add grpcAddress field to raft peer configuration

Add commented example of grpcAddress field in peer configuration,
which specifies the gRPC endpoint for raft peer communication.

* feat(raft): Add comprehensive test coverage for FSM operations

Add test cases covering:
- Weighted round-robin operations (single and batch updates)
- Round-robin index management
- Invalid command handling
- FSM snapshot restoration
- Node shutdown scenarios

* refactor: introduce dedicated ApplyTimeout constant for Raft operations

Replace usage of LeaderElectionTimeout with a new dedicated ApplyTimeout constant
(2 seconds) for Raft command applications across different load balancing
strategies (ConsistentHash, RoundRobin, WeightedRoundRobin).

This change provides better separation of concerns by using a more appropriate
timeout value for command applications rather than reusing the leader election
timeout.
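
The constant itself is a one-liner; under this change it would sit in the raft package alongside LeaderElectionTimeout (the 2-second value comes from this commit message; the exact placement is an assumption):

// ApplyTimeout bounds how long load balancers wait for a Raft command
// application, e.g. RaftNode.Apply(cmdBytes, raft.ApplyTimeout).
const ApplyTimeout = 2 * time.Second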
sinadarbouy authored Dec 21, 2024
1 parent 32c80f7 commit 86ec724
Showing 11 changed files with 626 additions and 105 deletions.
1 change: 1 addition & 0 deletions gatewayd.yaml
@@ -115,3 +115,4 @@ raft:
   peers: []
   # - id: node2
   #   address: 127.0.0.1:2223
+  #   grpcAddress: 127.0.0.1:50052
22 changes: 13 additions & 9 deletions network/consistenthash.go
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"net"
+	"strconv"
 	"sync"
 
 	gerr "github.com/gatewayd-io/gatewayd/errors"
@@ -69,10 +70,12 @@ func (ch *ConsistentHash) NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError) {
 	}
 
 	// Create and apply the command through Raft
-	cmd := raft.ConsistentHashCommand{
-		Type:      raft.CommandAddConsistentHashEntry,
-		Hash:      hash,
-		BlockName: proxy.GetBlockName(),
+	cmd := raft.Command{
+		Type: raft.CommandAddConsistentHashEntry,
+		Payload: raft.ConsistentHashPayload{
+			Hash:      hash,
+			BlockName: proxy.GetBlockName(),
+		},
 	}
 
 	cmdBytes, marshalErr := json.Marshal(cmd)
@@ -81,17 +84,18 @@
 	}
 
 	// Apply the command through Raft
-	if err := ch.server.RaftNode.Apply(cmdBytes, raft.LeaderElectionTimeout); err != nil {
+	if err := ch.server.RaftNode.Apply(cmdBytes, raft.ApplyTimeout); err != nil {
 		return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
 	}
 
 	return proxy, nil
 }
 
-// hashKey hashes a given key using the MurmurHash3 algorithm. It is used to generate consistent hash values
-// for IP addresses or connection strings.
-func hashKey(key string) uint64 {
-	return murmur3.Sum64([]byte(key))
+// hashKey hashes a given key using the MurmurHash3 algorithm and returns it as a string. It is used to generate
+// consistent hash values for IP addresses or connection strings.
+func hashKey(key string) string {
+	hash := murmur3.Sum64([]byte(key))
+	return strconv.FormatUint(hash, 10)
 }
 
 // extractIPFromConn extracts the IP address from the connection's remote address. It splits the address
10 changes: 6 additions & 4 deletions network/consistenthash_test.go
@@ -70,10 +70,12 @@ func TestConsistentHashNextProxyUseSourceIpExists(t *testing.T) {
 	// Instead of setting hashMap directly, setup the FSM
 	hash := hashKey("192.168.1.1" + server.GroupName)
 	// Create and apply the command through Raft
-	cmd := raft.ConsistentHashCommand{
-		Type:      raft.CommandAddConsistentHashEntry,
-		Hash:      hash,
-		BlockName: proxies[2].GetBlockName(),
+	cmd := raft.Command{
+		Type: raft.CommandAddConsistentHashEntry,
+		Payload: raft.ConsistentHashPayload{
+			Hash:      hash,
+			BlockName: proxies[2].GetBlockName(),
+		},
 	}
 
 	cmdBytes, marshalErr := json.Marshal(cmd)
2 changes: 2 additions & 0 deletions network/network_helpers_test.go
@@ -242,6 +242,8 @@ func setupProxy(
 			ClientConfig:  clientConfig,
 			Logger:        logger,
 			PluginTimeout: config.DefaultPluginTimeout,
+			GroupName:     "test-group",
+			BlockName:     clientIP + ":" + clientPort,
 		},
 	)
39 changes: 35 additions & 4 deletions network/roundrobin.go
@@ -1,21 +1,26 @@
 package network
 
 import (
+	"encoding/json"
 	"math"
-	"sync/atomic"
+	"sync"
 
 	gerr "github.com/gatewayd-io/gatewayd/errors"
+	"github.com/gatewayd-io/gatewayd/raft"
 )
 
 type RoundRobin struct {
 	proxies []IProxy
-	next    atomic.Uint32
+	server  *Server
+	mu      sync.Mutex
 }
 
 // NewRoundRobin creates a new RoundRobin load balancer.
 func NewRoundRobin(server *Server) *RoundRobin {
-	return &RoundRobin{proxies: server.Proxies}
+	return &RoundRobin{proxies: server.Proxies, server: server}
 }
 
 // NextProxy returns the next proxy in the round-robin sequence.
 func (r *RoundRobin) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) {
 	if len(r.proxies) > math.MaxUint32 {
 		// This should never happen, but if it does, we fall back to the first proxy.
@@ -24,6 +29,32 @@ func (r *RoundRobin) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) {
 		return nil, gerr.ErrNoProxiesAvailable
 	}
 
-	nextIndex := r.next.Add(1)
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// Get current index from Raft FSM
+	currentIndex := r.server.RaftNode.Fsm.GetRoundRobinNext(r.server.GroupName)
+	nextIndex := currentIndex + 1
+
+	// Create Raft command
+	cmd := raft.Command{
+		Type: raft.CommandAddRoundRobinNext,
+		Payload: raft.RoundRobinPayload{
+			NextIndex: nextIndex,
+			GroupName: r.server.GroupName,
+		},
+	}
+
+	// Convert command to JSON
+	data, err := json.Marshal(cmd)
+	if err != nil {
+		return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
+	}
+
+	// Apply through Raft
+	if err := r.server.RaftNode.Apply(data, raft.ApplyTimeout); err != nil {
+		return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
+	}
+
 	return r.proxies[nextIndex%uint32(len(r.proxies))], nil //nolint:gosec
 }
59 changes: 54 additions & 5 deletions network/roundrobin_test.go
@@ -1,9 +1,14 @@
 package network
 
 import (
+	"encoding/json"
 	"math"
 	"sync"
 	"testing"
+
+	"github.com/gatewayd-io/gatewayd/raft"
+	"github.com/gatewayd-io/gatewayd/testhelpers"
+	"github.com/stretchr/testify/require"
 )
 
 // TestNewRoundRobin tests the NewRoundRobin function to ensure that it correctly initializes
@@ -25,12 +30,21 @@ func TestNewRoundRobin(t *testing.T) {
 // TestRoundRobin_NextProxy tests the NextProxy method of the round-robin load balancer to ensure
 // that it returns proxies in the expected order.
 func TestRoundRobin_NextProxy(t *testing.T) {
+	raftHelper, err := testhelpers.NewTestRaftNode(t)
+	if err != nil {
+		t.Fatalf("Failed to create test raft node: %v", err)
+	}
+	defer func() {
+		if err := raftHelper.Cleanup(); err != nil {
+			t.Errorf("Failed to cleanup raft: %v", err)
+		}
+	}()
 	proxies := []IProxy{
 		MockProxy{name: "proxy1"},
 		MockProxy{name: "proxy2"},
 		MockProxy{name: "proxy3"},
 	}
-	server := &Server{Proxies: proxies}
+	server := &Server{Proxies: proxies, RaftNode: raftHelper.Node, GroupName: "test-group"}
 	roundRobin := NewRoundRobin(server)
 
 	expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"}
@@ -58,7 +72,18 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) {
 		MockProxy{name: "proxy2"},
 		MockProxy{name: "proxy3"},
 	}
-	server := &Server{Proxies: proxies}
+
+	raftHelper, err := testhelpers.NewTestRaftNode(t)
+	if err != nil {
+		t.Fatalf("Failed to create test raft node: %v", err)
+	}
+	defer func() {
+		if err := raftHelper.Cleanup(); err != nil {
+			t.Errorf("Failed to cleanup raft: %v", err)
+		}
+	}()
+	server := &Server{Proxies: proxies, RaftNode: raftHelper.Node, GroupName: "test-group"}
+	server.initializeProxies()
 	roundRobin := NewRoundRobin(server)
 
 	var waitGroup sync.WaitGroup
@@ -73,7 +98,7 @@
 	}
 
 	waitGroup.Wait()
-	nextIndex := roundRobin.next.Load()
+	nextIndex := server.RaftNode.Fsm.GetRoundRobinNext(server.GroupName)
 	if nextIndex != uint32(numGoroutines) {
 		t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex)
 	}
@@ -84,18 +109,42 @@
 // uint32 value and ensures that the proxy selection wraps around as expected when the
 // counter overflows.
 func TestNextProxyOverflow(t *testing.T) {
+	raftHelper, err := testhelpers.NewTestRaftNode(t)
+	if err != nil {
+		t.Fatalf("Failed to create test raft node: %v", err)
+	}
+	defer func() {
+		if err := raftHelper.Cleanup(); err != nil {
+			t.Errorf("Failed to cleanup raft: %v", err)
+		}
+	}()
 	// Create a server with a few mock proxies
 	server := &Server{
 		Proxies: []IProxy{
 			&MockProxy{},
 			&MockProxy{},
 			&MockProxy{},
 		},
+		GroupName: "test-group",
+		RaftNode:  raftHelper.Node,
 	}
+	server.initializeProxies()
 
 	roundRobin := NewRoundRobin(server)
 
 	// Set the next value to near the max uint32 value to force an overflow
-	roundRobin.next.Store(math.MaxUint32 - 1)
+	cmd := raft.Command{
+		Type: raft.CommandAddRoundRobinNext,
+		Payload: raft.RoundRobinPayload{
+			NextIndex: math.MaxUint32 - 1,
+			GroupName: server.GroupName,
+		},
+	}
+	// Convert command to JSON
+	data, err := json.Marshal(cmd)
+	require.NoError(t, err)
+
+	require.NoError(t, raftHelper.Node.Apply(data, raft.ApplyTimeout))
 
 	// Call NextProxy multiple times to trigger the overflow
 	for range 4 {
@@ -110,7 +159,7 @@
 
 	// After overflow, next value should wrap around
 	expectedNextValue := uint32(2) // (MaxUint32 - 1 + 4) % ProxiesLen = 2
-	actualNextValue := roundRobin.next.Load()
+	actualNextValue := server.RaftNode.Fsm.GetRoundRobinNext(server.GroupName)
 	if actualNextValue != expectedNextValue {
 		t.Fatalf("Expected next value to be %v, got %v", expectedNextValue, actualNextValue)
 	}
31 changes: 26 additions & 5 deletions network/server_test.go
@@ -78,6 +78,16 @@ func TestRunServer(t *testing.T) {
 	proxy1 := setupProxy(ctx, t, postgresHostIP1, postgresMappedPort1.Port(), logger, pluginRegistry)
 	proxy2 := setupProxy(ctx, t, postgresHostIP2, postgresMappedPort2.Port(), logger, pluginRegistry)
 
+	raftHelper, err := testhelpers.NewTestRaftNode(t)
+	if err != nil {
+		t.Fatalf("Failed to create test raft node: %v", err)
+	}
+	defer func() {
+		if err := raftHelper.Cleanup(); err != nil {
+			t.Errorf("Failed to cleanup raft: %v", err)
+		}
+	}()
+
 	// Create a server.
 	server := NewServer(
 		ctx,
@@ -94,6 +104,8 @@
 			PluginTimeout:            config.DefaultPluginTimeout,
 			HandshakeTimeout:         config.DefaultHandshakeTimeout,
 			LoadbalancerStrategyName: config.RoundRobinStrategy,
+			GroupName:                "test-group",
+			RaftNode:                 raftHelper.Node,
 		},
 	)
 	assert.NotNil(t, server)
@@ -114,9 +126,10 @@
 		}
 	}(t, server)
 
+	var proxyStateMutex sync.Mutex
+
 	testProxy := func(
 		t *testing.T,
-		proxy *Proxy,
 		waitGroup *sync.WaitGroup,
 	) {
 		t.Helper()
@@ -163,8 +176,16 @@
 		// AuthenticationOk.
 		assert.Equal(t, uint8(0x52), data[0])
 
-		assert.Equal(t, 2, proxy.AvailableConnections.Size())
-		assert.Equal(t, 1, proxy.busyConnections.Size())
+		// Lock the mutex before checking the proxy states
+		proxyStateMutex.Lock()
+		defer proxyStateMutex.Unlock()
+
+		// Check that one of the proxies has the expected state.
+		proxyInExpectedState := (proxy1.AvailableConnections.Size() == 2 && proxy1.busyConnections.Size() == 1) ||
+			(proxy2.AvailableConnections.Size() == 2 && proxy2.busyConnections.Size() == 1)
+		if !proxyInExpectedState {
+			t.Errorf("Neither proxy is in the expected state")
+		}
 
 		// Terminate the connection.
 		sent, err = client.Send(CreatePgTerminatePacket())
@@ -180,8 +201,8 @@
 	// Test both proxies.
 	// Based on the default Loadbalancer strategy (RoundRobin), the first client request will be sent to proxy2,
 	// followed by proxy1 for the next request.
-	go testProxy(t, proxy2, &waitGroup)
-	go testProxy(t, proxy1, &waitGroup)
+	go testProxy(t, &waitGroup)
+	go testProxy(t, &waitGroup)
 
 	// Wait for all goroutines.
 	waitGroup.Wait()