From 520ba61f19aae7b7ce81fab2b586fe73d41c129e Mon Sep 17 00:00:00 2001 From: Sherif Akoush Date: Fri, 1 Nov 2024 11:41:59 +0000 Subject: [PATCH 1/3] fix(agent): allow serial order for servers connection per instance (#6020) * enforce serial order on agent connections * add serial order for dataflow engine * move logging after acquiring lock * lint fixes * add tests for chainer subscribe * increase sleep to fix flaky test --- scheduler/pkg/agent/server.go | 14 ++- scheduler/pkg/agent/server_test.go | 131 +++++++++++++++++++- scheduler/pkg/kafka/dataflow/server.go | 24 ++-- scheduler/pkg/kafka/dataflow/server_test.go | 126 ++++++++++++++++++- 4 files changed, 283 insertions(+), 12 deletions(-) diff --git a/scheduler/pkg/agent/server.go b/scheduler/pkg/agent/server.go index 4cdb6ccaec..6744916ee9 100644 --- a/scheduler/pkg/agent/server.go +++ b/scheduler/pkg/agent/server.go @@ -112,6 +112,7 @@ type Server struct { certificateStore *seldontls.CertificateStore waiter *modelRelocatedWaiter // waiter for when we want to drain a particular server replica autoscalingServiceEnabled bool + agentMutex sync.Map // to force a serial order per agent (serverName, replicaIdx) } type SchedulerAgent interface { @@ -138,6 +139,7 @@ func NewAgentServer( scheduler: scheduler, waiter: newModelRelocatedWaiter(), autoscalingServiceEnabled: autoscalingServiceEnabled, + agentMutex: sync.Map{}, } hub.RegisterModelEventHandler( @@ -383,12 +385,20 @@ func (s *Server) ModelScalingTrigger(stream pb.AgentService_ModelScalingTriggerS func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentService_SubscribeServer) error { logger := s.logger.WithField("func", "Subscribe") + key := ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx} + + // this is forcing a serial order per agent (serverName, replicaIdx) + // in general this will make sure that a given agent disconnects fully before another agent is allowed to connect + mu, _ := s.agentMutex.LoadOrStore(key, &sync.Mutex{}) + mu.(*sync.Mutex).Lock() + defer mu.(*sync.Mutex).Unlock() + logger.Infof("Received subscribe request from %s:%d", request.ServerName, request.ReplicaIdx) fin := make(chan bool) s.mutex.Lock() - s.agents[ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}] = &AgentSubscriber{ + s.agents[key] = &AgentSubscriber{ finished: fin, stream: stream, } @@ -414,7 +424,7 @@ func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentSer case <-ctx.Done(): logger.Infof("Client replica %s:%d has disconnected", request.ServerName, request.ReplicaIdx) s.mutex.Lock() - delete(s.agents, ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}) + delete(s.agents, key) s.mutex.Unlock() s.removeServerReplicaImpl(request.GetServerName(), int(request.GetReplicaIdx())) // this is non-blocking beyond rescheduling models on removed server return nil diff --git a/scheduler/pkg/agent/server_test.go b/scheduler/pkg/agent/server_test.go index 088c6e6237..e95783e91a 100644 --- a/scheduler/pkg/agent/server_test.go +++ b/scheduler/pkg/agent/server_test.go @@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t package agent import ( + "context" "fmt" "testing" "time" @@ -17,14 +18,31 @@ import ( . 
"github.com/onsi/gomega" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent" pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent" pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" + testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils" + "github.com/seldonio/seldon-core/scheduler/v2/pkg/scheduler" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store" ) +type mockScheduler struct { +} + +var _ scheduler.Scheduler = (*mockScheduler)(nil) + +func (s mockScheduler) Schedule(_ string) error { + return nil +} + +func (s mockScheduler) ScheduleFailedModels() ([]string, error) { + return nil, nil +} + type mockStore struct { models map[string]*store.ModelSnapshot } @@ -91,7 +109,7 @@ func (m *mockStore) UpdateModelState(modelKey string, version uint32, serverKey } func (m *mockStore) AddServerReplica(request *pb.AgentSubscribeRequest) error { - panic("implement me") + return nil } func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error { @@ -99,7 +117,7 @@ func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error { } func (m *mockStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) { - panic("implement me") + return nil, nil } func (m *mockStore) DrainServerReplica(serverName string, replicaIdx int) ([]string, error) { @@ -943,3 +961,112 @@ func TestAutoscalingEnabled(t *testing.T) { } } + +func TestSubscribe(t *testing.T) { + log.SetLevel(log.DebugLevel) + g := NewGomegaWithT(t) + + type ag struct { + id uint32 + doClose bool + } + type test struct { + name string + agents []ag + expectedAgentsCount int + expectedAgentsCountAfterClose int + } + tests := []test{ + { + name: "simple", + agents: []ag{ + {1, true}, {2, true}, + }, + expectedAgentsCount: 2, + expectedAgentsCountAfterClose: 0, + }, + { + name: "simple - no close", + agents: []ag{ + {1, true}, {2, false}, + }, + expectedAgentsCount: 2, + expectedAgentsCountAfterClose: 1, + }, + { + name: "duplicates", + agents: []ag{ + {1, true}, {1, false}, + }, + expectedAgentsCount: 1, + expectedAgentsCountAfterClose: 1, + }, + { + name: "duplicates with all close", + agents: []ag{ + {1, true}, {1, true}, {1, true}, + }, + expectedAgentsCount: 1, + expectedAgentsCountAfterClose: 0, + }, + } + + getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn { + conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) + grpcClient := agent.NewAgentServiceClient(conn) + _, _ = grpcClient.Subscribe( + context, + &agent.AgentSubscribeRequest{ + ServerName: "dummy", + ReplicaIdx: id, + ReplicaConfig: &agent.ReplicaConfig{}, + Shared: true, + AvailableMemoryBytes: 0, + }, + ) + return conn + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + logger := log.New() + eventHub, err := coordinator.NewEventHub(logger) + g.Expect(err).To(BeNil()) + server := NewAgentServer(logger, &mockStore{}, mockScheduler{}, eventHub, false) + port, err := testing_utils.GetFreePortForTest() + if err != nil { + t.Fatal(err) + } + err = server.startServer(uint(port), false) + if err != nil { + t.Fatal(err) + } + time.Sleep(100 * time.Millisecond) + + streams := make([]*grpc.ClientConn, 0) + for _, a := range test.agents { + go func(id uint32) { + conn := getStream(id, context.Background(), port) + streams = 
append(streams, conn) + }(a.id) + } + + time.Sleep(500 * time.Millisecond) + + g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCount)) + + for idx, s := range streams { + go func(idx int, s *grpc.ClientConn) { + if test.agents[idx].doClose { + s.Close() + } + }(idx, s) + } + + time.Sleep(10 * time.Second) + + g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCountAfterClose)) + + server.StopAgentStreams() + }) + } +} diff --git a/scheduler/pkg/kafka/dataflow/server.go b/scheduler/pkg/kafka/dataflow/server.go index cdb2ce1969..aeed2ce299 100644 --- a/scheduler/pkg/kafka/dataflow/server.go +++ b/scheduler/pkg/kafka/dataflow/server.go @@ -44,6 +44,7 @@ type ChainerServer struct { pipelineHandler pipeline.PipelineHandler topicNamer *kafka.TopicNamer loadBalancer util.LoadBalancer + chainerMutex sync.Map chainer.UnimplementedChainerServer } @@ -66,6 +67,7 @@ func NewChainerServer(logger log.FieldLogger, eventHub *coordinator.EventHub, pi pipelineHandler: pipelineHandler, topicNamer: topicNamer, loadBalancer: loadBalancer, + chainerMutex: sync.Map{}, } eventHub.RegisterPipelineEventHandler( @@ -125,17 +127,25 @@ func (c *ChainerServer) PipelineUpdateEvent(ctx context.Context, message *chaine func (c *ChainerServer) SubscribePipelineUpdates(req *chainer.PipelineSubscriptionRequest, stream chainer.Chainer_SubscribePipelineUpdatesServer) error { logger := c.logger.WithField("func", "SubscribePipelineStatus") + + key := req.GetName() + // this is forcing a serial order per dataflow-engine + // in general this will make sure that a given dataflow-engine disconnects fully before another dataflow-engine is allowed to connect + mu, _ := c.chainerMutex.LoadOrStore(key, &sync.Mutex{}) + mu.(*sync.Mutex).Lock() + defer mu.(*sync.Mutex).Unlock() + logger.Infof("Received subscribe request from %s", req.GetName()) fin := make(chan bool) c.mu.Lock() - c.streams[req.Name] = &ChainerSubscription{ - name: req.Name, + c.streams[key] = &ChainerSubscription{ + name: key, stream: stream, fin: fin, } - c.loadBalancer.AddServer(req.Name) + c.loadBalancer.AddServer(key) c.mu.Unlock() // Handle addition of new server @@ -148,13 +158,13 @@ func (c *ChainerServer) SubscribePipelineUpdates(req *chainer.PipelineSubscripti for { select { case <-fin: - logger.Infof("Closing stream for %s", req.GetName()) + logger.Infof("Closing stream for %s", key) return nil case <-ctx.Done(): - logger.Infof("Stream disconnected %s", req.GetName()) + logger.Infof("Stream disconnected %s", key) c.mu.Lock() - c.loadBalancer.RemoveServer(req.Name) - delete(c.streams, req.Name) + c.loadBalancer.RemoveServer(key) + delete(c.streams, key) c.mu.Unlock() // Handle removal of server c.rebalance() diff --git a/scheduler/pkg/kafka/dataflow/server_test.go b/scheduler/pkg/kafka/dataflow/server_test.go index d7f0860f4b..ee862e53c1 100644 --- a/scheduler/pkg/kafka/dataflow/server_test.go +++ b/scheduler/pkg/kafka/dataflow/server_test.go @@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t package dataflow import ( + "context" "fmt" "os" "testing" @@ -18,11 +19,13 @@ import ( . 
"github.com/onsi/gomega" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "github.com/seldonio/seldon-core/apis/go/v2/mlops/chainer" "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" + testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils" "github.com/seldonio/seldon-core/scheduler/v2/pkg/kafka" "github.com/seldonio/seldon-core/scheduler/v2/pkg/kafka/config" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store" @@ -321,7 +324,7 @@ func TestPipelineRollingUpgradeEvents(t *testing.T) { } // to allow events to propagate - time.Sleep(500 * time.Millisecond) + time.Sleep(700 * time.Millisecond) if test.connection { if test.loadReqV2 != nil { @@ -634,6 +637,127 @@ func TestPipelineRebalance(t *testing.T) { } } +func TestPipelineSubscribe(t *testing.T) { + g := NewGomegaWithT(t) + + type ag struct { + id uint32 + doClose bool + } + + type test struct { + name string + agents []ag + expectedAgentsCount int + expectedAgentsCountAfterClose int + } + + tests := []test{ + { + name: "single connection", + agents: []ag{ + {id: 1, doClose: true}, + }, + expectedAgentsCount: 1, + expectedAgentsCountAfterClose: 0, + }, + { + name: "multiple connection - one not closed", + agents: []ag{ + {id: 1, doClose: false}, {id: 2, doClose: true}, + }, + expectedAgentsCount: 2, + expectedAgentsCountAfterClose: 1, + }, + { + name: "multiple connection - not closed", + agents: []ag{ + {id: 1, doClose: false}, {id: 2, doClose: false}, + }, + expectedAgentsCount: 2, + expectedAgentsCountAfterClose: 2, + }, + { + name: "multiple connection - closed", + agents: []ag{ + {id: 1, doClose: true}, {id: 2, doClose: true}, + }, + expectedAgentsCount: 2, + expectedAgentsCountAfterClose: 0, + }, + { + name: "multiple connection - duplicate", + agents: []ag{ + {id: 1, doClose: true}, {id: 1, doClose: true}, {id: 1, doClose: true}, + }, + expectedAgentsCount: 1, + expectedAgentsCountAfterClose: 0, + }, + { + name: "multiple connection - duplicate not closed", + agents: []ag{ + {id: 1, doClose: true}, {id: 1, doClose: false}, {id: 1, doClose: true}, + }, + expectedAgentsCount: 1, + expectedAgentsCountAfterClose: 1, + }, + } + + getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn { + conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials())) + grpcClient := chainer.NewChainerClient(conn) + _, _ = grpcClient.SubscribePipelineUpdates( + context, + &chainer.PipelineSubscriptionRequest{ + Name: fmt.Sprintf("agent-%d", id), + }, + ) + return conn + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serverName := "dummy" + s, _ := createTestScheduler(t, serverName) + port, err := testing_utils.GetFreePortForTest() + if err != nil { + t.Fatal(err) + } + go func() { + _ = s.StartGrpcServer(uint(port)) + }() + + time.Sleep(100 * time.Millisecond) + + streams := make([]*grpc.ClientConn, 0) + for _, a := range test.agents { + go func(id uint32) { + conn := getStream(id, context.Background(), port) + streams = append(streams, conn) + }(a.id) + } + + time.Sleep(700 * time.Millisecond) + + g.Expect(len(s.streams)).To(Equal(test.expectedAgentsCount)) + + for idx, s := range streams { + go func(idx int, s *grpc.ClientConn) { + if test.agents[idx].doClose { + s.Close() + } + }(idx, s) + } + + time.Sleep(10 * time.Second) + + 
g.Expect(len(s.streams)).To(Equal(test.expectedAgentsCountAfterClose)) + + s.StopSendPipelineEvents() + }) + } +} + type stubChainerServer struct { msgs chan *chainer.PipelineUpdateMessage grpc.ServerStream From cb5cfbdcfd076e6a000692e087d1ee42ffc0752a Mon Sep 17 00:00:00 2001 From: Sherif Akoush Date: Fri, 1 Nov 2024 14:46:13 +0000 Subject: [PATCH 2/3] fix(ci): fix flaky test (#6022) * fix flaky test * cap execution of test in case of failures --- scheduler/pkg/agent/server_test.go | 19 +++++++++++++++---- scheduler/pkg/kafka/dataflow/server_test.go | 20 +++++++++++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/scheduler/pkg/agent/server_test.go b/scheduler/pkg/agent/server_test.go index e95783e91a..75b6adf56b 100644 --- a/scheduler/pkg/agent/server_test.go +++ b/scheduler/pkg/agent/server_test.go @@ -12,6 +12,7 @@ package agent import ( "context" "fmt" + "sync" "testing" "time" @@ -1042,16 +1043,23 @@ func TestSubscribe(t *testing.T) { } time.Sleep(100 * time.Millisecond) + mu := sync.Mutex{} streams := make([]*grpc.ClientConn, 0) for _, a := range test.agents { go func(id uint32) { conn := getStream(id, context.Background(), port) + mu.Lock() streams = append(streams, conn) + mu.Unlock() }(a.id) } - time.Sleep(500 * time.Millisecond) - + maxCount := 10 + count := 0 + for len(server.agents) != test.expectedAgentsCount && count < maxCount { + time.Sleep(100 * time.Millisecond) + count++ + } g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCount)) for idx, s := range streams { @@ -1062,8 +1070,11 @@ func TestSubscribe(t *testing.T) { }(idx, s) } - time.Sleep(10 * time.Second) - + count = 0 + for len(server.agents) != test.expectedAgentsCountAfterClose && count < maxCount { + time.Sleep(100 * time.Millisecond) + count++ + } g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCountAfterClose)) server.StopAgentStreams() diff --git a/scheduler/pkg/kafka/dataflow/server_test.go b/scheduler/pkg/kafka/dataflow/server_test.go index ee862e53c1..e52780a9a1 100644 --- a/scheduler/pkg/kafka/dataflow/server_test.go +++ b/scheduler/pkg/kafka/dataflow/server_test.go @@ -13,6 +13,7 @@ import ( "context" "fmt" "os" + "sync" "testing" "time" @@ -639,7 +640,6 @@ func TestPipelineRebalance(t *testing.T) { func TestPipelineSubscribe(t *testing.T) { g := NewGomegaWithT(t) - type ag struct { id uint32 doClose bool @@ -729,16 +729,23 @@ func TestPipelineSubscribe(t *testing.T) { time.Sleep(100 * time.Millisecond) + mu := sync.Mutex{} streams := make([]*grpc.ClientConn, 0) for _, a := range test.agents { go func(id uint32) { conn := getStream(id, context.Background(), port) + mu.Lock() streams = append(streams, conn) + mu.Unlock() }(a.id) } - time.Sleep(700 * time.Millisecond) - + maxCount := 10 + count := 0 + for len(s.streams) != test.expectedAgentsCount && count < maxCount { + time.Sleep(100 * time.Millisecond) + count++ + } g.Expect(len(s.streams)).To(Equal(test.expectedAgentsCount)) for idx, s := range streams { @@ -749,8 +756,11 @@ func TestPipelineSubscribe(t *testing.T) { }(idx, s) } - time.Sleep(10 * time.Second) - + count = 0 + for len(s.streams) != test.expectedAgentsCountAfterClose && count < maxCount { + time.Sleep(100 * time.Millisecond) + count++ + } g.Expect(len(s.streams)).To(Equal(test.expectedAgentsCountAfterClose)) s.StopSendPipelineEvents() From 2c79e778720d68bc2bc070b8ebf77d3b60d13542 Mon Sep 17 00:00:00 2001 From: Lucian Carata Date: Fri, 1 Nov 2024 15:07:17 +0000 Subject: [PATCH 3/3] fix(dataflow): make each replica use unique subscription names 
(#6021)

Following #6020, it was no longer possible to have multiple replicas of the
dataflow-engine subscribing to the scheduler simultaneously, because all
replicas connected with the same subscriber name and a lock was added per
name, waiting for the old subscriber to disconnect fully before allowing a
new one to proceed.

We update the dataflow-engine code so that each replica connects with its
own hostname as the subscriber name. If the hostname cannot be determined,
we subscribe with the name seldon-dataflow-engine- followed by the canonical
string representation of a UUID v4. The subscriber name can also be set
explicitly by passing the --dataflow-replica-id argument or the
DATAFLOW_REPLICA_ID environment variable, which take precedence, in that
order, over the hostname-based value.
---
 .../src/main/kotlin/io/seldon/dataflow/Cli.kt | 23 +++++++++++++-
 .../main/kotlin/io/seldon/dataflow/Main.kt    |  4 ++-
 .../src/main/resources/local.properties       |  1 +
 .../test/kotlin/io/seldon/dataflow/CliTest.kt | 31 +++++++++++++++++++
 4 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
index 1e5485aaea..94a5e393a5 100644
--- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
+++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
@@ -11,6 +11,7 @@ package io.seldon.dataflow
 
 import com.natpryce.konfig.CommandLineOption
 import com.natpryce.konfig.Configuration
+import com.natpryce.konfig.ConfigurationMap
 import com.natpryce.konfig.ConfigurationProperties
 import com.natpryce.konfig.EnvironmentVariables
 import com.natpryce.konfig.Key
@@ -25,6 +26,8 @@ import io.klogging.Level
 import io.klogging.noCoLogger
 import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms
 import io.seldon.dataflow.kafka.security.KafkaSecurityProtocols
+import java.net.InetAddress
+import java.util.UUID
 
 object Cli {
     private const val ENV_VAR_PREFIX = "SELDON_"
@@ -34,6 +37,7 @@ object Cli {
     val logLevelApplication = Key("log.level.app", enumType(*Level.values()))
     val logLevelKafka = Key("log.level.kafka", enumType(*Level.values()))
     val namespace = Key("pod.namespace", stringType)
+    val dataflowReplicaId = Key("dataflow.replica.id", stringType)
 
     // Seldon components
     val upstreamHost = Key("upstream.host", stringType)
@@ -75,6 +79,7 @@ object Cli {
             logLevelApplication,
             logLevelKafka,
             namespace,
+            dataflowReplicaId,
             upstreamHost,
             upstreamPort,
             kafkaBootstrapServers,
@@ -105,10 +110,26 @@ object Cli {
 
     fun configWith(rawArgs: Array<String>): Configuration {
         val fromProperties = ConfigurationProperties.fromResource("local.properties")
+        val fromSystem = getSystemConfig()
         val fromEnv = EnvironmentVariables(prefix = ENV_VAR_PREFIX)
         val fromArgs = parseArguments(rawArgs)
 
-        return fromArgs overriding fromEnv overriding fromProperties
+        return fromArgs overriding fromEnv overriding fromSystem overriding fromProperties
+    }
+
+    private fun getSystemConfig(): Configuration {
+        val dataflowIdPair = this.dataflowReplicaId to getNewDataflowId()
+        return ConfigurationMap(dataflowIdPair)
+    }
+
+    fun getNewDataflowId(assignRandomUuid: Boolean = false): String {
+        if (!assignRandomUuid) {
+            try {
+                return InetAddress.getLocalHost().hostName
+            } catch (_: Exception) {
+            }
+        }
+        return "seldon-dataflow-engine-" + UUID.randomUUID().toString()
     }
 
     private fun parseArguments(rawArgs: Array<String>): Configuration {
diff --git a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt 
b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt index 8d4a899eaa..b064a974f5 100644 --- a/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt +++ b/scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Main.kt @@ -102,9 +102,11 @@ object Main { describeRetries = config[Cli.topicDescribeRetries], describeRetryDelayMillis = config[Cli.topicDescribeRetryDelayMillis], ) + val subscriberId = config[Cli.dataflowReplicaId] + val subscriber = PipelineSubscriber( - "seldon-dataflow-engine", + subscriberId, kafkaProperties, kafkaAdminProperties, kafkaStreamsParams, diff --git a/scheduler/data-flow/src/main/resources/local.properties b/scheduler/data-flow/src/main/resources/local.properties index 46a7218380..68b3bd408d 100644 --- a/scheduler/data-flow/src/main/resources/local.properties +++ b/scheduler/data-flow/src/main/resources/local.properties @@ -1,5 +1,6 @@ log.level.app=INFO log.level.kafka=WARN +dataflow.replica.id=seldon-dataflow-engine kafka.bootstrap.servers=localhost:9092 kafka.consumer.prefix= kafka.security.protocol=PLAINTEXT diff --git a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt index 9011ff3d4c..52a97fa4b5 100644 --- a/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt +++ b/scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt @@ -16,9 +16,15 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.Arguments.arguments import org.junit.jupiter.params.provider.MethodSource import strikt.api.expectCatching +import strikt.api.expectThat +import strikt.assertions.hasLength import strikt.assertions.isEqualTo +import strikt.assertions.isNotEqualTo import strikt.assertions.isSuccess +import strikt.assertions.startsWith +import java.util.UUID import java.util.stream.Stream +import kotlin.test.Test internal class CliTest { @DisplayName("Passing auth mechanism via cli argument") @@ -36,6 +42,31 @@ internal class CliTest { .isEqualTo(expectedMechanism) } + @Test + fun `should handle dataflow replica id`() { + val cliDefault = Cli.configWith(arrayOf()) + val testReplicaId = "dataflow-id-1" + val cli = Cli.configWith(arrayOf("--dataflow-replica-id", testReplicaId)) + + expectThat(cliDefault[Cli.dataflowReplicaId]) { + isNotEqualTo("seldon-dataflow-engine") + } + expectThat(cli[Cli.dataflowReplicaId]) { + isEqualTo(testReplicaId) + } + + // test random Uuid (v4) + val expectedReplicaIdPrefix = "seldon-dataflow-engine-" + val uuidStringLength = 36 + val randomReplicaUuid = Cli.getNewDataflowId(true) + expectThat(randomReplicaUuid) { + startsWith(expectedReplicaIdPrefix) + hasLength(expectedReplicaIdPrefix.length + uuidStringLength) + } + expectCatching { UUID.fromString(randomReplicaUuid.removePrefix(expectedReplicaIdPrefix)) } + .isSuccess() + } + companion object { @JvmStatic private fun saslMechanisms(): Stream {