diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
index 1e5485aaea..94a5e393a5 100644
--- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
+++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
@@ -11,6 +11,7 @@ package io.seldon.dataflow
 
 import com.natpryce.konfig.CommandLineOption
 import com.natpryce.konfig.Configuration
+import com.natpryce.konfig.ConfigurationMap
 import com.natpryce.konfig.ConfigurationProperties
 import com.natpryce.konfig.EnvironmentVariables
 import com.natpryce.konfig.Key
@@ -25,6 +26,8 @@ import io.klogging.Level
 import io.klogging.noCoLogger
 import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms
 import io.seldon.dataflow.kafka.security.KafkaSecurityProtocols
+import java.net.InetAddress
+import java.util.UUID
 
 object Cli {
     private const val ENV_VAR_PREFIX = "SELDON_"
@@ -34,6 +37,7 @@ object Cli {
     val logLevelApplication = Key("log.level.app", enumType(*Level.values()))
     val logLevelKafka = Key("log.level.kafka", enumType(*Level.values()))
     val namespace = Key("pod.namespace", stringType)
+    val dataflowReplicaId = Key("dataflow.replica.id", stringType)
 
     // Seldon components
     val upstreamHost = Key("upstream.host", stringType)
@@ -75,6 +79,7 @@ object Cli {
            logLevelApplication,
            logLevelKafka,
            namespace,
+           dataflowReplicaId,
            upstreamHost,
            upstreamPort,
            kafkaBootstrapServers,
@@ -105,10 +110,26 @@ object Cli {
 
     fun configWith(rawArgs: Array<String>): Configuration {
         val fromProperties = ConfigurationProperties.fromResource("local.properties")
+        val fromSystem = getSystemConfig()
         val fromEnv = EnvironmentVariables(prefix = ENV_VAR_PREFIX)
         val fromArgs = parseArguments(rawArgs)
 
-        return fromArgs overriding fromEnv overriding fromProperties
+        return fromArgs overriding fromEnv overriding fromSystem overriding fromProperties
+    }
+
+    private fun getSystemConfig(): Configuration {
+        val dataflowIdPair = this.dataflowReplicaId to getNewDataflowId()
+        return ConfigurationMap(dataflowIdPair)
+    }
+
+    fun getNewDataflowId(assignRandomUuid: Boolean = false): String {
+        if (!assignRandomUuid) {
+            try {
+                return InetAddress.getLocalHost().hostName
+            } catch (_: Exception) {
+            }
+        }
+        return "seldon-dataflow-engine-" + UUID.randomUUID().toString()
     }
 
     private fun parseArguments(rawArgs: Array<String>): Configuration {
diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt
index 8d4a899eaa..b064a974f5 100644
--- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt
+++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt
@@ -102,9 +102,11 @@ object Main {
            describeRetries = config[Cli.topicDescribeRetries],
            describeRetryDelayMillis = config[Cli.topicDescribeRetryDelayMillis],
        )
+       val subscriberId = config[Cli.dataflowReplicaId]
+
 
        val subscriber = PipelineSubscriber(
-           "seldon-dataflow-engine",
+           subscriberId,
            kafkaProperties,
            kafkaAdminProperties,
            kafkaStreamsParams,
diff --git a/scheduler/data-flow/src/main/resources/local.properties b/scheduler/data-flow/src/main/resources/local.properties
index 46a7218380..68b3bd408d 100644
--- a/scheduler/data-flow/src/main/resources/local.properties
+++ b/scheduler/data-flow/src/main/resources/local.properties
@@ -1,5 +1,6 @@
 log.level.app=INFO
 log.level.kafka=WARN
+dataflow.replica.id=seldon-dataflow-engine
 kafka.bootstrap.servers=localhost:9092
 kafka.consumer.prefix=
 kafka.security.protocol=PLAINTEXT
diff --git a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt
index 9011ff3d4c..52a97fa4b5 100644
--- a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt
+++ b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt
@@ -16,9 +16,15 @@ import org.junit.jupiter.params.provider.Arguments
 import org.junit.jupiter.params.provider.Arguments.arguments
 import org.junit.jupiter.params.provider.MethodSource
 import strikt.api.expectCatching
+import strikt.api.expectThat
+import strikt.assertions.hasLength
 import strikt.assertions.isEqualTo
+import strikt.assertions.isNotEqualTo
 import strikt.assertions.isSuccess
+import strikt.assertions.startsWith
+import java.util.UUID
 import java.util.stream.Stream
+import kotlin.test.Test
 
 internal class CliTest {
     @DisplayName("Passing auth mechanism via cli argument")
@@ -36,6 +42,31 @@ internal class CliTest {
             .isEqualTo(expectedMechanism)
     }
 
+    @Test
+    fun `should handle dataflow replica id`() {
+        val cliDefault = Cli.configWith(arrayOf())
+        val testReplicaId = "dataflow-id-1"
+        val cli = Cli.configWith(arrayOf("--dataflow-replica-id", testReplicaId))
+
+        expectThat(cliDefault[Cli.dataflowReplicaId]) {
+            isNotEqualTo("seldon-dataflow-engine")
+        }
+        expectThat(cli[Cli.dataflowReplicaId]) {
+            isEqualTo(testReplicaId)
+        }
+
+        // test random Uuid (v4)
+        val expectedReplicaIdPrefix = "seldon-dataflow-engine-"
+        val uuidStringLength = 36
+        val randomReplicaUuid = Cli.getNewDataflowId(true)
+        expectThat(randomReplicaUuid) {
+            startsWith(expectedReplicaIdPrefix)
+            hasLength(expectedReplicaIdPrefix.length + uuidStringLength)
+        }
+        expectCatching { UUID.fromString(randomReplicaUuid.removePrefix(expectedReplicaIdPrefix)) }
+            .isSuccess()
+    }
+
     companion object {
         @JvmStatic
         private fun saslMechanisms(): Stream<Arguments> {
diff --git a/scheduler/pkg/agent/server.go b/scheduler/pkg/agent/server.go
index 4cdb6ccaec..6744916ee9 100644
--- a/scheduler/pkg/agent/server.go
+++ b/scheduler/pkg/agent/server.go
@@ -112,6 +112,7 @@ type Server struct {
    certificateStore          *seldontls.CertificateStore
    waiter                    *modelRelocatedWaiter // waiter for when we want to drain a particular server replica
    autoscalingServiceEnabled bool
+   agentMutex                sync.Map // to force a serial order per agent (serverName, replicaIdx)
 }
 
 type SchedulerAgent interface {
@@ -138,6 +139,7 @@ func NewAgentServer(
        scheduler:                 scheduler,
        waiter:                    newModelRelocatedWaiter(),
        autoscalingServiceEnabled: autoscalingServiceEnabled,
+       agentMutex:                sync.Map{},
    }
 
    hub.RegisterModelEventHandler(
@@ -383,12 +385,20 @@ func (s *Server) ModelScalingTrigger(stream pb.AgentService_ModelScalingTriggerS
 
 func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentService_SubscribeServer) error {
    logger := s.logger.WithField("func", "Subscribe")
+   key := ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}
+
+   // this is forcing a serial order per agent (serverName, replicaIdx)
+   // in general this will make sure that a given agent disconnects fully before another agent is allowed to connect
+   mu, _ := s.agentMutex.LoadOrStore(key, &sync.Mutex{})
+   mu.(*sync.Mutex).Lock()
+   defer mu.(*sync.Mutex).Unlock()
+
    logger.Infof("Received subscribe request from %s:%d", request.ServerName, request.ReplicaIdx)
 
    fin := make(chan bool)
 
    s.mutex.Lock()
-   s.agents[ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}] = &AgentSubscriber{
+   s.agents[key] = &AgentSubscriber{
        finished: fin,
        stream:   stream,
    }
@@ -414,7 +424,7 @@ func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentSer
        case <-ctx.Done():
            logger.Infof("Client replica %s:%d has disconnected", request.ServerName, request.ReplicaIdx)
            s.mutex.Lock()
-           delete(s.agents, ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx})
+           delete(s.agents, key)
            s.mutex.Unlock()
            s.removeServerReplicaImpl(request.GetServerName(), int(request.GetReplicaIdx())) // this is non-blocking beyond rescheduling models on removed server
            return nil
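The `LoadOrStore` idiom above gives each `(serverName, replicaIdx)` pair its own mutex, so a reconnecting agent cannot enter `Subscribe` until the previous stream for the same key has finished its cleanup, while agents with different keys proceed in parallel. A minimal, self-contained sketch of that pattern (the `perKeyLock` type and the key string are illustrative, not part of this change):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// perKeyLock hands out one mutex per key, so work for the same key is
// serialized while different keys proceed concurrently.
type perKeyLock struct {
	locks sync.Map // key -> *sync.Mutex
}

func (p *perKeyLock) withLock(key string, fn func()) {
	mu, _ := p.locks.LoadOrStore(key, &sync.Mutex{})
	mu.(*sync.Mutex).Lock()
	defer mu.(*sync.Mutex).Unlock()
	fn()
}

func main() {
	var p perKeyLock
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// all three goroutines use the same key, so they run one at a time
			p.withLock("server-0:1", func() {
				fmt.Println("holding lock for server-0:1, goroutine", i)
				time.Sleep(50 * time.Millisecond)
			})
		}(i)
	}
	wg.Wait()
}
```

As in the servers above, entries are never removed from the map in this sketch, which is acceptable while the key space (the set of known agents) stays small.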
diff --git a/scheduler/pkg/agent/server_test.go b/scheduler/pkg/agent/server_test.go
index 088c6e6237..75b6adf56b 100644
--- a/scheduler/pkg/agent/server_test.go
+++ b/scheduler/pkg/agent/server_test.go
@@ -10,21 +10,40 @@ the Change License after the Change Date as each is defined in accordance with t
 package agent
 
 import (
+   "context"
    "fmt"
+   "sync"
    "testing"
    "time"
 
    . "github.com/onsi/gomega"
    log "github.com/sirupsen/logrus"
    "google.golang.org/grpc"
+   "google.golang.org/grpc/credentials/insecure"
 
+   "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
    pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
    pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"
 
    "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
+   testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
+   "github.com/seldonio/seldon-core/scheduler/v2/pkg/scheduler"
    "github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
 )
 
+type mockScheduler struct {
+}
+
+var _ scheduler.Scheduler = (*mockScheduler)(nil)
+
+func (s mockScheduler) Schedule(_ string) error {
+   return nil
+}
+
+func (s mockScheduler) ScheduleFailedModels() ([]string, error) {
+   return nil, nil
+}
+
 type mockStore struct {
    models map[string]*store.ModelSnapshot
 }
@@ -91,7 +110,7 @@ func (m *mockStore) UpdateModelState(modelKey string, version uint32, serverKey
 }
 
 func (m *mockStore) AddServerReplica(request *pb.AgentSubscribeRequest) error {
-   panic("implement me")
+   return nil
 }
 
 func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error {
@@ -99,7 +118,7 @@ func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error {
 }
 
 func (m *mockStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) {
-   panic("implement me")
+   return nil, nil
 }
 
 func (m *mockStore) DrainServerReplica(serverName string, replicaIdx int) ([]string, error) {
@@ -943,3 +962,122 @@ func TestAutoscalingEnabled(t *testing.T) {
        })
    }
 }
+
+func TestSubscribe(t *testing.T) {
+   log.SetLevel(log.DebugLevel)
+   g := NewGomegaWithT(t)
+
+   type ag struct {
+       id      uint32
+       doClose bool
+   }
+   type test struct {
+       name                          string
+       agents                        []ag
+       expectedAgentsCount           int
+       expectedAgentsCountAfterClose int
+   }
+   tests := []test{
+       {
+           name: "simple",
+           agents: []ag{
+               {1, true}, {2, true},
+           },
+           expectedAgentsCount:           2,
+           expectedAgentsCountAfterClose: 0,
+       },
+       {
+           name: "simple - no close",
+           agents: []ag{
+               {1, true}, {2, false},
+           },
+           expectedAgentsCount:           2,
+           expectedAgentsCountAfterClose: 1,
+       },
+       {
+           name: "duplicates",
+           agents: []ag{
+               {1, true}, {1, false},
+           },
+           expectedAgentsCount:           1,
+           expectedAgentsCountAfterClose: 1,
+       },
+       {
+           name: "duplicates with all close",
+           agents: []ag{
+               {1, true}, {1, true}, {1, true},
+           },
+           expectedAgentsCount:           1,
+           expectedAgentsCountAfterClose: 0,
+       },
+   }
+
+   getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn {
+       conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
+       grpcClient := agent.NewAgentServiceClient(conn)
+       _, _ = grpcClient.Subscribe(
+           context,
+           &agent.AgentSubscribeRequest{
+               ServerName:           "dummy",
+               ReplicaIdx:           id,
+               ReplicaConfig:        &agent.ReplicaConfig{},
+               Shared:               true,
+               AvailableMemoryBytes: 0,
+           },
+       )
+       return conn
+   }
+   for _, test := range tests {
+       t.Run(test.name, func(t *testing.T) {
+           logger := log.New()
+           eventHub, err := coordinator.NewEventHub(logger)
+           g.Expect(err).To(BeNil())
+           server := NewAgentServer(logger, &mockStore{}, mockScheduler{}, eventHub, false)
+           port, err := testing_utils.GetFreePortForTest()
+           if err != nil {
+               t.Fatal(err)
+           }
+           err = server.startServer(uint(port), false)
+           if err != nil {
+               t.Fatal(err)
+           }
+           time.Sleep(100 * time.Millisecond)
+
+           mu := sync.Mutex{}
+           streams := make([]*grpc.ClientConn, 0)
+           for _, a := range test.agents {
+               go func(id uint32) {
+                   conn := getStream(id, context.Background(), port)
+                   mu.Lock()
+                   streams = append(streams, conn)
+                   mu.Unlock()
+               }(a.id)
+           }
+
+           maxCount := 10
+           count := 0
+           for len(server.agents) != test.expectedAgentsCount && count < maxCount {
+               time.Sleep(100 * time.Millisecond)
+               count++
+           }
+           g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCount))
+
+           for idx, s := range streams {
+               go func(idx int, s *grpc.ClientConn) {
+                   if test.agents[idx].doClose {
+                       s.Close()
+                   }
+               }(idx, s)
+           }
+
+           count = 0
+           for len(server.agents) != test.expectedAgentsCountAfterClose && count < maxCount {
+               time.Sleep(100 * time.Millisecond)
+               count++
+           }
+           g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCountAfterClose))
+
+           server.StopAgentStreams()
+       })
+   }
+}
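Rather than asserting after a single fixed sleep, the test above polls `len(server.agents)` in a bounded retry loop before asserting. A hypothetical helper expressing the same wait-until-condition pattern (not part of the change; the test deliberately inlines the loop):

```go
package main

import (
	"fmt"
	"time"
)

// waitFor polls cond every interval until it returns true or the attempts are
// exhausted, mirroring the bounded retry loops used in the tests above.
func waitFor(cond func() bool, interval time.Duration, maxAttempts int) bool {
	for i := 0; i < maxAttempts; i++ {
		if cond() {
			return true
		}
		time.Sleep(interval)
	}
	return cond()
}

func main() {
	start := time.Now()
	ok := waitFor(func() bool { return time.Since(start) > 250*time.Millisecond }, 100*time.Millisecond, 10)
	fmt.Println("condition met:", ok)
}
```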
diff --git a/scheduler/pkg/kafka/dataflow/server.go b/scheduler/pkg/kafka/dataflow/server.go
index cdb2ce1969..aeed2ce299 100644
--- a/scheduler/pkg/kafka/dataflow/server.go
+++ b/scheduler/pkg/kafka/dataflow/server.go
@@ -44,6 +44,7 @@ type ChainerServer struct {
    pipelineHandler pipeline.PipelineHandler
    topicNamer      *kafka.TopicNamer
    loadBalancer    util.LoadBalancer
+   chainerMutex    sync.Map
    chainer.UnimplementedChainerServer
 }
 
@@ -66,6 +67,7 @@ func NewChainerServer(logger log.FieldLogger, eventHub *coordinator.EventHub, pi
        pipelineHandler: pipelineHandler,
        topicNamer:      topicNamer,
        loadBalancer:    loadBalancer,
+       chainerMutex:    sync.Map{},
    }
 
    eventHub.RegisterPipelineEventHandler(
@@ -125,17 +127,25 @@ func (c *ChainerServer) PipelineUpdateEvent(ctx context.Context, message *chaine
 
 func (c *ChainerServer) SubscribePipelineUpdates(req *chainer.PipelineSubscriptionRequest, stream chainer.Chainer_SubscribePipelineUpdatesServer) error {
    logger := c.logger.WithField("func", "SubscribePipelineStatus")
+
+   key := req.GetName()
+   // this is forcing a serial order per dataflow-engine
+   // in general this will make sure that a given dataflow-engine disconnects fully before another dataflow-engine is allowed to connect
+   mu, _ := c.chainerMutex.LoadOrStore(key, &sync.Mutex{})
+   mu.(*sync.Mutex).Lock()
+   defer mu.(*sync.Mutex).Unlock()
+
    logger.Infof("Received subscribe request from %s", req.GetName())
 
    fin := make(chan bool)
 
    c.mu.Lock()
-   c.streams[req.Name] = &ChainerSubscription{
-       name:   req.Name,
+   c.streams[key] = &ChainerSubscription{
+       name:   key,
        stream: stream,
        fin:    fin,
    }
-   c.loadBalancer.AddServer(req.Name)
+   c.loadBalancer.AddServer(key)
    c.mu.Unlock()
 
    // Handle addition of new server
@@ -148,13 +158,13 @@ func (c *ChainerServer) SubscribePipelineUpdates(req *chainer.PipelineSubscripti
    for {
        select {
        case <-fin:
-           logger.Infof("Closing stream for %s", req.GetName())
+           logger.Infof("Closing stream for %s", key)
            return nil
        case <-ctx.Done():
-           logger.Infof("Stream disconnected %s", req.GetName())
+           logger.Infof("Stream disconnected %s", key)
            c.mu.Lock()
-           c.loadBalancer.RemoveServer(req.Name)
-           delete(c.streams, req.Name)
+           c.loadBalancer.RemoveServer(key)
+           delete(c.streams, key)
            c.mu.Unlock()
            // Handle removal of server
            c.rebalance()
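The same serialization is applied here, keyed on the subscriber name, which is the dataflow replica id introduced on the Kotlin side. Because `LoadOrStore` returns the value already stored for an existing key, two connections presenting the same name contend on a single mutex, while distinct names get independent locks, so a dataflow engine reconnecting under its old name waits for the previous stream's `RemoveServer`/`delete` cleanup to finish. A small sketch of that property (the replica names are illustrative):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var chainerMutex sync.Map

	// LoadOrStore returns the mutex already stored for a key, so repeated
	// lookups under the same name share one lock.
	a, _ := chainerMutex.LoadOrStore("seldon-dataflow-engine-0", &sync.Mutex{})
	b, _ := chainerMutex.LoadOrStore("seldon-dataflow-engine-0", &sync.Mutex{})
	c, _ := chainerMutex.LoadOrStore("seldon-dataflow-engine-1", &sync.Mutex{})

	fmt.Println("same name shares a lock:      ", a == b) // true
	fmt.Println("different names use new locks:", a == c) // false
}
```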
diff --git a/scheduler/pkg/kafka/dataflow/server_test.go b/scheduler/pkg/kafka/dataflow/server_test.go
index d7f0860f4b..e52780a9a1 100644
--- a/scheduler/pkg/kafka/dataflow/server_test.go
+++ b/scheduler/pkg/kafka/dataflow/server_test.go
@@ -10,19 +10,23 @@ the Change License after the Change Date as each is defined in accordance with t
 package dataflow
 
 import (
+   "context"
    "fmt"
    "os"
+   "sync"
    "testing"
    "time"
 
    . "github.com/onsi/gomega"
    log "github.com/sirupsen/logrus"
    "google.golang.org/grpc"
+   "google.golang.org/grpc/credentials/insecure"
 
    "github.com/seldonio/seldon-core/apis/go/v2/mlops/chainer"
    "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"
 
    "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
+   testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
    "github.com/seldonio/seldon-core/scheduler/v2/pkg/kafka"
    "github.com/seldonio/seldon-core/scheduler/v2/pkg/kafka/config"
    "github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
@@ -321,7 +325,7 @@ func TestPipelineRollingUpgradeEvents(t *testing.T) {
        }
 
        // to allow events to propagate
-       time.Sleep(500 * time.Millisecond)
+       time.Sleep(700 * time.Millisecond)
 
        if test.connection {
            if test.loadReqV2 != nil {
@@ -634,6 +638,136 @@ func TestPipelineRebalance(t *testing.T) {
    }
 }
 
+func TestPipelineSubscribe(t *testing.T) {
+   g := NewGomegaWithT(t)
+   type ag struct {
+       id      uint32
+       doClose bool
+   }
+
+   type test struct {
+       name                          string
+       agents                        []ag
+       expectedAgentsCount           int
+       expectedAgentsCountAfterClose int
+   }
+
+   tests := []test{
+       {
+           name: "single connection",
+           agents: []ag{
+               {id: 1, doClose: true},
+           },
+           expectedAgentsCount:           1,
+           expectedAgentsCountAfterClose: 0,
+       },
+       {
+           name: "multiple connection - one not closed",
+           agents: []ag{
+               {id: 1, doClose: false}, {id: 2, doClose: true},
+           },
+           expectedAgentsCount:           2,
+           expectedAgentsCountAfterClose: 1,
+       },
+       {
+           name: "multiple connection - not closed",
+           agents: []ag{
+               {id: 1, doClose: false}, {id: 2, doClose: false},
+           },
+           expectedAgentsCount:           2,
+           expectedAgentsCountAfterClose: 2,
+       },
+       {
+           name: "multiple connection - closed",
+           agents: []ag{
+               {id: 1, doClose: true}, {id: 2, doClose: true},
+           },
+           expectedAgentsCount:           2,
+           expectedAgentsCountAfterClose: 0,
+       },
+       {
+           name: "multiple connection - duplicate",
+           agents: []ag{
+               {id: 1, doClose: true}, {id: 1, doClose: true}, {id: 1, doClose: true},
+           },
+           expectedAgentsCount:           1,
+           expectedAgentsCountAfterClose: 0,
+       },
+       {
+           name: "multiple connection - duplicate not closed",
+           agents: []ag{
+               {id: 1, doClose: true}, {id: 1, doClose: false}, {id: 1, doClose: true},
+           },
+           expectedAgentsCount:           1,
+           expectedAgentsCountAfterClose: 1,
+       },
+   }
+
+   getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn {
+       conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
+       grpcClient := chainer.NewChainerClient(conn)
+       _, _ = grpcClient.SubscribePipelineUpdates(
+           context,
+           &chainer.PipelineSubscriptionRequest{
+               Name: fmt.Sprintf("agent-%d", id),
+           },
+       )
+       return conn
+   }
+
+   for _, test := range tests {
+       t.Run(test.name, func(t *testing.T) {
+           serverName := "dummy"
+           s, _ := createTestScheduler(t, serverName)
+           port, err := testing_utils.GetFreePortForTest()
+           if err != nil {
+               t.Fatal(err)
+           }
+           go func() {
+               _ = s.StartGrpcServer(uint(port))
+           }()
+
+           time.Sleep(100 * time.Millisecond)
+
+           mu := sync.Mutex{}
+           streams := make([]*grpc.ClientConn, 0)
+           for _, a := range test.agents {
+               go func(id uint32) {
+                   conn := getStream(id, context.Background(), port)
+                   mu.Lock()
+                   streams = append(streams, conn)
+                   mu.Unlock()
+               }(a.id)
+           }
+
+           maxCount := 10
+           count := 0
+           for len(s.streams) != test.expectedAgentsCount && count < maxCount {
+               time.Sleep(100 * time.Millisecond)
+               count++
+           }
+           g.Expect(len(s.streams)).To(Equal(test.expectedAgentsCount))
+
+           for idx, s := range streams {
+               go func(idx int, s *grpc.ClientConn) {
+                   if test.agents[idx].doClose {
+                       s.Close()
+                   }
+               }(idx, s)
+           }
+
+           count = 0
+           for len(s.streams) != test.expectedAgentsCountAfterClose && count < maxCount {
+               time.Sleep(100 * time.Millisecond)
+               count++
+           }
+           g.Expect(len(s.streams)).To(Equal(test.expectedAgentsCountAfterClose))
+
+           s.StopSendPipelineEvents()
+       })
+   }
+}
+
 type stubChainerServer struct {
    msgs chan *chainer.PipelineUpdateMessage
    grpc.ServerStream