Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ci): Changes from v2 for release 2.8.5 (3) #6023

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package io.seldon.dataflow

import com.natpryce.konfig.CommandLineOption
import com.natpryce.konfig.Configuration
import com.natpryce.konfig.ConfigurationMap
import com.natpryce.konfig.ConfigurationProperties
import com.natpryce.konfig.EnvironmentVariables
import com.natpryce.konfig.Key
Expand All @@ -25,6 +26,8 @@ import io.klogging.Level
import io.klogging.noCoLogger
import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms
import io.seldon.dataflow.kafka.security.KafkaSecurityProtocols
import java.net.InetAddress
import java.util.UUID

object Cli {
private const val ENV_VAR_PREFIX = "SELDON_"
Expand All @@ -34,6 +37,7 @@ object Cli {
val logLevelApplication = Key("log.level.app", enumType(*Level.values()))
val logLevelKafka = Key("log.level.kafka", enumType(*Level.values()))
val namespace = Key("pod.namespace", stringType)
val dataflowReplicaId = Key("dataflow.replica.id", stringType)

// Seldon components
val upstreamHost = Key("upstream.host", stringType)
Expand Down Expand Up @@ -75,6 +79,7 @@ object Cli {
logLevelApplication,
logLevelKafka,
namespace,
dataflowReplicaId,
upstreamHost,
upstreamPort,
kafkaBootstrapServers,
Expand Down Expand Up @@ -105,10 +110,26 @@ object Cli {

/**
 * Builds the effective configuration from all sources.
 *
 * Precedence (highest first): CLI arguments, SELDON_-prefixed environment
 * variables, system-derived values (e.g. the per-replica dataflow id), then
 * the bundled local.properties defaults.
 *
 * Fix: the flattened diff left two return statements; the second was
 * unreachable dead code. Only the merged four-source chain is kept.
 */
fun configWith(rawArgs: Array<String>): Configuration {
    val fromProperties = ConfigurationProperties.fromResource("local.properties")
    val fromSystem = getSystemConfig()
    val fromEnv = EnvironmentVariables(prefix = ENV_VAR_PREFIX)
    val fromArgs = parseArguments(rawArgs)

    return fromArgs overriding fromEnv overriding fromSystem overriding fromProperties
}

// Configuration derived from the running system rather than files/env/args:
// currently just a replica id based on the pod hostname (or a random UUID).
private fun getSystemConfig(): Configuration =
    ConfigurationMap(dataflowReplicaId to getNewDataflowId())

/**
 * Produces an identifier for this dataflow replica.
 *
 * By default the local hostname is used (stable across restarts of the same
 * pod). When [assignRandomUuid] is true — or when the hostname lookup throws —
 * a "seldon-dataflow-engine-<uuid>" fallback id is generated instead.
 */
fun getNewDataflowId(assignRandomUuid: Boolean = false): String {
    if (!assignRandomUuid) {
        val hostName =
            try {
                InetAddress.getLocalHost().hostName
            } catch (e: Exception) {
                // Best-effort: any lookup failure falls through to the UUID id.
                null
            }
        if (hostName != null) {
            return hostName
        }
    }
    return "seldon-dataflow-engine-${UUID.randomUUID()}"
}

private fun parseArguments(rawArgs: Array<String>): Configuration {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ object Main {
describeRetries = config[Cli.topicDescribeRetries],
describeRetryDelayMillis = config[Cli.topicDescribeRetryDelayMillis],
)
val subscriberId = config[Cli.dataflowReplicaId]

val subscriber =
PipelineSubscriber(
"seldon-dataflow-engine",
subscriberId,
kafkaProperties,
kafkaAdminProperties,
kafkaStreamsParams,
Expand Down
1 change: 1 addition & 0 deletions scheduler/data-flow/src/main/resources/local.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
log.level.app=INFO
log.level.kafka=WARN
dataflow.replica.id=seldon-dataflow-engine
kafka.bootstrap.servers=localhost:9092
kafka.consumer.prefix=
kafka.security.protocol=PLAINTEXT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@ import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.Arguments.arguments
import org.junit.jupiter.params.provider.MethodSource
import strikt.api.expectCatching
import strikt.api.expectThat
import strikt.assertions.hasLength
import strikt.assertions.isEqualTo
import strikt.assertions.isNotEqualTo
import strikt.assertions.isSuccess
import strikt.assertions.startsWith
import java.util.UUID
import java.util.stream.Stream
import kotlin.test.Test

internal class CliTest {
@DisplayName("Passing auth mechanism via cli argument")
Expand All @@ -36,6 +42,31 @@ internal class CliTest {
.isEqualTo(expectedMechanism)
}

@Test
fun `should handle dataflow replica id`() {
    val defaultConfig = Cli.configWith(arrayOf<String>())
    val replicaIdArg = "dataflow-id-1"
    val overriddenConfig = Cli.configWith(arrayOf("--dataflow-replica-id", replicaIdArg))

    // The default id is host-derived, so it must differ from the static
    // fallback value shipped in local.properties; an explicit CLI argument
    // must win over every other source.
    expectThat(defaultConfig[Cli.dataflowReplicaId]).isNotEqualTo("seldon-dataflow-engine")
    expectThat(overriddenConfig[Cli.dataflowReplicaId]).isEqualTo(replicaIdArg)

    // A forced-random id is the fixed prefix followed by a parseable UUID (v4).
    val prefix = "seldon-dataflow-engine-"
    val uuidLength = 36
    val randomReplicaId = Cli.getNewDataflowId(true)
    expectThat(randomReplicaId) {
        startsWith(prefix)
        hasLength(prefix.length + uuidLength)
    }
    expectCatching { UUID.fromString(randomReplicaId.removePrefix(prefix)) }
        .isSuccess()
}

companion object {
@JvmStatic
private fun saslMechanisms(): Stream<Arguments> {
Expand Down
14 changes: 12 additions & 2 deletions scheduler/pkg/agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ type Server struct {
certificateStore *seldontls.CertificateStore
waiter *modelRelocatedWaiter // waiter for when we want to drain a particular server replica
autoscalingServiceEnabled bool
agentMutex sync.Map // to force a serial order per agent (serverName, replicaIdx)
}

type SchedulerAgent interface {
Expand All @@ -138,6 +139,7 @@ func NewAgentServer(
scheduler: scheduler,
waiter: newModelRelocatedWaiter(),
autoscalingServiceEnabled: autoscalingServiceEnabled,
agentMutex: sync.Map{},
}

hub.RegisterModelEventHandler(
Expand Down Expand Up @@ -383,12 +385,20 @@ func (s *Server) ModelScalingTrigger(stream pb.AgentService_ModelScalingTriggerS

func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentService_SubscribeServer) error {
logger := s.logger.WithField("func", "Subscribe")
key := ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}

// this is forcing a serial order per agent (serverName, replicaIdx)
// in general this will make sure that a given agent disconnects fully before another agent is allowed to connect
mu, _ := s.agentMutex.LoadOrStore(key, &sync.Mutex{})
mu.(*sync.Mutex).Lock()
defer mu.(*sync.Mutex).Unlock()

logger.Infof("Received subscribe request from %s:%d", request.ServerName, request.ReplicaIdx)

fin := make(chan bool)

s.mutex.Lock()
s.agents[ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}] = &AgentSubscriber{
s.agents[key] = &AgentSubscriber{
finished: fin,
stream: stream,
}
Expand All @@ -414,7 +424,7 @@ func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentSer
case <-ctx.Done():
logger.Infof("Client replica %s:%d has disconnected", request.ServerName, request.ReplicaIdx)
s.mutex.Lock()
delete(s.agents, ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx})
delete(s.agents, key)
s.mutex.Unlock()
s.removeServerReplicaImpl(request.GetServerName(), int(request.GetReplicaIdx())) // this is non-blocking beyond rescheduling models on removed server
return nil
Expand Down
142 changes: 140 additions & 2 deletions scheduler/pkg/agent/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,40 @@ the Change License after the Change Date as each is defined in accordance with t
package agent

import (
"context"
"fmt"
"sync"
"testing"
"time"

. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/scheduler"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
)

// mockScheduler is a no-op stand-in for the real scheduler so that
// NewAgentServer can be constructed in tests; every operation "succeeds"
// without scheduling anything.
type mockScheduler struct {
}

// Compile-time check that mockScheduler satisfies scheduler.Scheduler.
var _ scheduler.Scheduler = (*mockScheduler)(nil)

// Schedule reports success without doing any work.
func (s mockScheduler) Schedule(_ string) error {
	return nil
}

// ScheduleFailedModels reports no failed models and no error.
func (s mockScheduler) ScheduleFailedModels() ([]string, error) {
	return nil, nil
}

// mockStore is a minimal in-memory stub of the model store used by the agent
// server tests; only the methods those tests exercise return real values,
// the rest panic.
type mockStore struct {
	models map[string]*store.ModelSnapshot
}
Expand Down Expand Up @@ -91,15 +110,15 @@ func (m *mockStore) UpdateModelState(modelKey string, version uint32, serverKey
}

// AddServerReplica accepts any subscribe request unconditionally so that
// Subscribe tests can register replicas against the mock store.
// Fix: the flattened diff kept a leading panic("implement me") that made the
// return unreachable; only the intended no-op success path remains.
func (m *mockStore) AddServerReplica(request *pb.AgentSubscribeRequest) error {
	return nil
}

// ServerNotify is deliberately unimplemented: the tests using this mock are
// not expected to reach it, so any call is surfaced loudly as a panic.
func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error {
	panic("implement me")
}

// RemoveServerReplica succeeds with no models to reschedule, letting agent
// disconnect paths complete against the mock store.
// Fix: the flattened diff kept a leading panic("implement me") that made the
// return unreachable; only the intended no-op success path remains.
func (m *mockStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) {
	return nil, nil
}

func (m *mockStore) DrainServerReplica(serverName string, replicaIdx int) ([]string, error) {
Expand Down Expand Up @@ -943,3 +962,122 @@ func TestAutoscalingEnabled(t *testing.T) {
}

}

// TestSubscribe drives the agent Subscribe stream end-to-end over a real
// gRPC connection and checks the server's agent bookkeeping as clients
// connect and disconnect, including duplicate (serverName, replicaIdx)
// registrations racing each other.
func TestSubscribe(t *testing.T) {
	log.SetLevel(log.DebugLevel)
	g := NewGomegaWithT(t)

	// ag is one simulated agent: its replica id and whether its client
	// connection gets closed mid-test.
	type ag struct {
		id      uint32
		doClose bool
	}
	type test struct {
		name                          string
		agents                        []ag
		expectedAgentsCount           int // registered agents after all subscribes
		expectedAgentsCountAfterClose int // agents remaining after selective closes
	}
	tests := []test{
		{
			name: "simple",
			agents: []ag{
				{1, true}, {2, true},
			},
			expectedAgentsCount:           2,
			expectedAgentsCountAfterClose: 0,
		},
		{
			name: "simple - no close",
			agents: []ag{
				{1, true}, {2, false},
			},
			expectedAgentsCount:           2,
			expectedAgentsCountAfterClose: 1,
		},
		{
			name: "duplicates",
			agents: []ag{
				{1, true}, {1, false},
			},
			expectedAgentsCount:           1,
			expectedAgentsCountAfterClose: 1,
		},
		{
			name: "duplicates with all close",
			agents: []ag{
				{1, true}, {1, true}, {1, true},
			},
			expectedAgentsCount:           1,
			expectedAgentsCountAfterClose: 0,
		},
	}

	// getStream dials the test server and issues a Subscribe for the given
	// replica id, returning the client connection so the caller can close it.
	// NOTE(review): errors from grpc.NewClient and Subscribe are discarded —
	// a connection failure would surface only as a count mismatch below.
	// The parameter named `context` also shadows the context package here.
	getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn {
		conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
		grpcClient := agent.NewAgentServiceClient(conn)
		_, _ = grpcClient.Subscribe(
			context,
			&agent.AgentSubscribeRequest{
				ServerName:           "dummy",
				ReplicaIdx:           id,
				ReplicaConfig:        &agent.ReplicaConfig{},
				Shared:               true,
				AvailableMemoryBytes: 0,
			},
		)
		return conn
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			logger := log.New()
			eventHub, err := coordinator.NewEventHub(logger)
			g.Expect(err).To(BeNil())
			server := NewAgentServer(logger, &mockStore{}, mockScheduler{}, eventHub, false)
			port, err := testing_utils.GetFreePortForTest()
			if err != nil {
				t.Fatal(err)
			}
			err = server.startServer(uint(port), false)
			if err != nil {
				t.Fatal(err)
			}
			// Give the gRPC server a moment to start accepting connections.
			time.Sleep(100 * time.Millisecond)

			mu := sync.Mutex{}
			streams := make([]*grpc.ClientConn, 0)
			// Subscribe all agents concurrently to exercise the server-side
			// per-agent serialization of connect/disconnect.
			for _, a := range test.agents {
				go func(id uint32) {
					conn := getStream(id, context.Background(), port)
					mu.Lock()
					streams = append(streams, conn)
					mu.Unlock()
				}(a.id)
			}

			// Bounded poll (up to ~1s) for the expected registration count.
			// NOTE(review): len(server.agents) is read without holding
			// server.mutex while Subscribe handlers write the map — likely a
			// data race under `go test -race`; confirm and guard if flagged.
			maxCount := 10
			count := 0
			for len(server.agents) != test.expectedAgentsCount && count < maxCount {
				time.Sleep(100 * time.Millisecond)
				count++
			}
			g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCount))

			// Close the selected client connections concurrently; the server
			// should deregister only those agents.
			for idx, s := range streams {
				go func(idx int, s *grpc.ClientConn) {
					if test.agents[idx].doClose {
						s.Close()
					}
				}(idx, s)
			}

			count = 0
			for len(server.agents) != test.expectedAgentsCountAfterClose && count < maxCount {
				time.Sleep(100 * time.Millisecond)
				count++
			}
			g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCountAfterClose))

			server.StopAgentStreams()
		})
	}
}
Loading
Loading