fix(agent): allow serial order for servers connection per instance (#6020)


* enforce serial order on agent connections

* add serial order for dataflow engine

* move logging after acquiring lock

* lint fixes

* add tests for chainer subscribe

* increase sleep to fix flaky test
sakoush authored Nov 1, 2024
1 parent 659d4ba commit 520ba61
Showing 4 changed files with 283 additions and 12 deletions.
14 changes: 12 additions & 2 deletions scheduler/pkg/agent/server.go
@@ -112,6 +112,7 @@ type Server struct {
certificateStore *seldontls.CertificateStore
waiter *modelRelocatedWaiter // waiter for when we want to drain a particular server replica
autoscalingServiceEnabled bool
agentMutex sync.Map // to force a serial order per agent (serverName, replicaIdx)
}

type SchedulerAgent interface {
@@ -138,6 +139,7 @@ func NewAgentServer(
scheduler: scheduler,
waiter: newModelRelocatedWaiter(),
autoscalingServiceEnabled: autoscalingServiceEnabled,
agentMutex: sync.Map{},
}

hub.RegisterModelEventHandler(
@@ -383,12 +385,20 @@ func (s *Server) ModelScalingTrigger(stream pb.AgentService_ModelScalingTriggerS

func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentService_SubscribeServer) error {
logger := s.logger.WithField("func", "Subscribe")
key := ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}

// this is forcing a serial order per agent (serverName, replicaIdx)
// in general this will make sure that a given agent disconnects fully before another agent is allowed to connect
mu, _ := s.agentMutex.LoadOrStore(key, &sync.Mutex{})
mu.(*sync.Mutex).Lock()
defer mu.(*sync.Mutex).Unlock()

logger.Infof("Received subscribe request from %s:%d", request.ServerName, request.ReplicaIdx)

fin := make(chan bool)

s.mutex.Lock()
s.agents[ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}] = &AgentSubscriber{
s.agents[key] = &AgentSubscriber{
finished: fin,
stream: stream,
}
Expand All @@ -414,7 +424,7 @@ func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentSer
case <-ctx.Done():
logger.Infof("Client replica %s:%d has disconnected", request.ServerName, request.ReplicaIdx)
s.mutex.Lock()
delete(s.agents, ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx})
delete(s.agents, key)
s.mutex.Unlock()
s.removeServerReplicaImpl(request.GetServerName(), int(request.GetReplicaIdx())) // this is non-blocking beyond rescheduling models on removed server
return nil
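
The change above keys a *sync.Mutex by (serverName, replicaIdx) through sync.Map.LoadOrStore, so a replica that reconnects cannot enter Subscribe until the previous Subscribe handler for the same replica has returned. Below is a minimal, self-contained sketch of that behaviour; the key type, names, and timings are illustrative and not part of the Seldon codebase.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// replicaKey mirrors the role of ServerKey in the diff above (illustrative).
type replicaKey struct {
	serverName string
	replicaIdx uint32
}

func main() {
	var perReplica sync.Map // replicaKey -> *sync.Mutex
	var wg sync.WaitGroup

	subscribe := func(id int, key replicaKey) {
		defer wg.Done()
		// One mutex per (serverName, replicaIdx); LoadOrStore returns the
		// existing mutex if another connection already created it.
		mu, _ := perReplica.LoadOrStore(key, &sync.Mutex{})
		mu.(*sync.Mutex).Lock()
		defer mu.(*sync.Mutex).Unlock()

		fmt.Printf("connection %d for %s:%d starts\n", id, key.serverName, key.replicaIdx)
		time.Sleep(100 * time.Millisecond) // stand-in for the stream's lifetime
		fmt.Printf("connection %d for %s:%d fully disconnected\n", id, key.serverName, key.replicaIdx)
	}

	key := replicaKey{serverName: "mlserver", replicaIdx: 0}
	wg.Add(2)
	go subscribe(1, key)
	go subscribe(2, key) // whichever locks second blocks until the other has finished
	wg.Wait()
}
```

The same LoadOrStore-and-assert pattern is applied to the dataflow-engine connections further down, keyed by the engine name instead of a (server, replica) pair.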
131 changes: 129 additions & 2 deletions scheduler/pkg/agent/server_test.go
@@ -10,21 +10,39 @@ the Change License after the Change Date as each is defined in accordance with t
package agent

import (
"context"
"fmt"
"testing"
"time"

. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/scheduler"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
)

type mockScheduler struct {
}

var _ scheduler.Scheduler = (*mockScheduler)(nil)

func (s mockScheduler) Schedule(_ string) error {
return nil
}

func (s mockScheduler) ScheduleFailedModels() ([]string, error) {
return nil, nil
}

type mockStore struct {
models map[string]*store.ModelSnapshot
}
@@ -91,15 +109,15 @@ func (m *mockStore) UpdateModelState(modelKey string, version uint32, serverKey
}

func (m *mockStore) AddServerReplica(request *pb.AgentSubscribeRequest) error {
panic("implement me")
return nil
}

func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error {
panic("implement me")
}

func (m *mockStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) {
panic("implement me")
return nil, nil
}

func (m *mockStore) DrainServerReplica(serverName string, replicaIdx int) ([]string, error) {
@@ -943,3 +961,112 @@ func TestAutoscalingEnabled(t *testing.T) {
}

}

func TestSubscribe(t *testing.T) {
log.SetLevel(log.DebugLevel)
g := NewGomegaWithT(t)

type ag struct {
id uint32
doClose bool
}
type test struct {
name string
agents []ag
expectedAgentsCount int
expectedAgentsCountAfterClose int
}
tests := []test{
{
name: "simple",
agents: []ag{
{1, true}, {2, true},
},
expectedAgentsCount: 2,
expectedAgentsCountAfterClose: 0,
},
{
name: "simple - no close",
agents: []ag{
{1, true}, {2, false},
},
expectedAgentsCount: 2,
expectedAgentsCountAfterClose: 1,
},
{
name: "duplicates",
agents: []ag{
{1, true}, {1, false},
},
expectedAgentsCount: 1,
expectedAgentsCountAfterClose: 1,
},
{
name: "duplicates with all close",
agents: []ag{
{1, true}, {1, true}, {1, true},
},
expectedAgentsCount: 1,
expectedAgentsCountAfterClose: 0,
},
}

getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn {
conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
grpcClient := agent.NewAgentServiceClient(conn)
_, _ = grpcClient.Subscribe(
context,
&agent.AgentSubscribeRequest{
ServerName: "dummy",
ReplicaIdx: id,
ReplicaConfig: &agent.ReplicaConfig{},
Shared: true,
AvailableMemoryBytes: 0,
},
)
return conn
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := log.New()
eventHub, err := coordinator.NewEventHub(logger)
g.Expect(err).To(BeNil())
server := NewAgentServer(logger, &mockStore{}, mockScheduler{}, eventHub, false)
port, err := testing_utils.GetFreePortForTest()
if err != nil {
t.Fatal(err)
}
err = server.startServer(uint(port), false)
if err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)

streams := make([]*grpc.ClientConn, 0)
for _, a := range test.agents {
go func(id uint32) {
conn := getStream(id, context.Background(), port)
streams = append(streams, conn)
}(a.id)
}

time.Sleep(500 * time.Millisecond)

g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCount))

for idx, s := range streams {
go func(idx int, s *grpc.ClientConn) {
if test.agents[idx].doClose {
s.Close()
}
}(idx, s)
}

time.Sleep(10 * time.Second)

g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCountAfterClose))

server.StopAgentStreams()
})
}
}
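
The test above synchronises with fixed time.Sleep calls; the final 10-second sleep is presumably the one referenced by the "increase sleep to fix flaky test" item in the commit message. A less timing-sensitive alternative would be to poll for the expected agent count. The helper below is a hypothetical sketch, not part of the commit, and it reads server.agents without locking exactly as the existing assertions do.

```go
// waitForAgentCount polls until the server reports the expected number of
// connected agents or the timeout expires. Hypothetical helper for the test
// package above; len(server.agents) is read directly, as the existing test does.
func waitForAgentCount(server *Server, want int, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if len(server.agents) == want {
			return true
		}
		time.Sleep(50 * time.Millisecond)
	}
	return len(server.agents) == want
}
```

The fixed sleeps before each assertion could then be replaced with, for example, g.Expect(waitForAgentCount(server, test.expectedAgentsCountAfterClose, 10*time.Second)).To(BeTrue()).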
24 changes: 17 additions & 7 deletions scheduler/pkg/kafka/dataflow/server.go
@@ -44,6 +44,7 @@ type ChainerServer struct {
pipelineHandler pipeline.PipelineHandler
topicNamer *kafka.TopicNamer
loadBalancer util.LoadBalancer
chainerMutex sync.Map
chainer.UnimplementedChainerServer
}

@@ -66,6 +67,7 @@ func NewChainerServer(logger log.FieldLogger, eventHub *coordinator.EventHub, pi
pipelineHandler: pipelineHandler,
topicNamer: topicNamer,
loadBalancer: loadBalancer,
chainerMutex: sync.Map{},
}

eventHub.RegisterPipelineEventHandler(
@@ -125,17 +127,25 @@ func (c *ChainerServer) PipelineUpdateEvent(ctx context.Context, message *chaine

func (c *ChainerServer) SubscribePipelineUpdates(req *chainer.PipelineSubscriptionRequest, stream chainer.Chainer_SubscribePipelineUpdatesServer) error {
logger := c.logger.WithField("func", "SubscribePipelineStatus")

key := req.GetName()
// this is forcing a serial order per dataflow-engine
// in general this will make sure that a given dataflow-engine disconnects fully before another dataflow-engine is allowed to connect
mu, _ := c.chainerMutex.LoadOrStore(key, &sync.Mutex{})
mu.(*sync.Mutex).Lock()
defer mu.(*sync.Mutex).Unlock()

logger.Infof("Received subscribe request from %s", req.GetName())

fin := make(chan bool)

c.mu.Lock()
c.streams[req.Name] = &ChainerSubscription{
name: req.Name,
c.streams[key] = &ChainerSubscription{
name: key,
stream: stream,
fin: fin,
}
c.loadBalancer.AddServer(req.Name)
c.loadBalancer.AddServer(key)
c.mu.Unlock()

// Handle addition of new server
@@ -148,13 +158,13 @@ for {
for {
select {
case <-fin:
logger.Infof("Closing stream for %s", req.GetName())
logger.Infof("Closing stream for %s", key)
return nil
case <-ctx.Done():
logger.Infof("Stream disconnected %s", req.GetName())
logger.Infof("Stream disconnected %s", key)
c.mu.Lock()
c.loadBalancer.RemoveServer(req.Name)
delete(c.streams, req.Name)
c.loadBalancer.RemoveServer(key)
delete(c.streams, key)
c.mu.Unlock()
// Handle removal of server
c.rebalance()
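
The ChainerServer applies the same per-key serialisation, keyed by the dataflow-engine name from req.GetName(), so a reconnecting engine waits until the previous SubscribePipelineUpdates handler for that name has returned. If the pattern were needed in more places, the locking could be factored into a small helper; the sketch below is illustrative only and not part of the commit.

```go
// lockForKey acquires the per-key mutex stored in m (creating it on first
// use) and returns the matching unlock function. Illustrative helper; the
// commit inlines this LoadOrStore pattern in both Subscribe handlers.
func lockForKey(m *sync.Map, key any) (unlock func()) {
	mu, _ := m.LoadOrStore(key, &sync.Mutex{})
	mu.(*sync.Mutex).Lock()
	return mu.(*sync.Mutex).Unlock
}

// usage: defer lockForKey(&c.chainerMutex, req.GetName())()
```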
