Fix annotation ingestion for Pulsar-backed scheduler (#2607)
* Cleanup api.go

* Ingest annotations, cleanup

* Cleanup

* Lint

* Fix test
severinson committed Jun 26, 2023
1 parent fac9411 commit f2cdeab
Showing 9 changed files with 128 additions and 117 deletions.
133 changes: 67 additions & 66 deletions internal/scheduler/api.go
@@ -24,16 +24,25 @@ import (
"github.com/armadaproject/armada/pkg/executorapi"
)

// ExecutorApi is a gRPC service that exposes functionality required by the armada executors
// ExecutorApi is the gRPC service executors use to synchronise their state with that of the scheduler.
type ExecutorApi struct {
producer pulsar.Producer
jobRepository database.JobRepository
executorRepository database.ExecutorRepository
legacyExecutorRepository database.ExecutorRepository
allowedPriorities []int32 // allowed priority classes
maxJobsPerCall uint // maximum number of jobs that will be leased in a single call
maxPulsarMessageSize uint // maximum sizer of pulsar messages produced
nodeIdLabel string
// Used to send Pulsar messages when, e.g., executors report a job has finished.
producer pulsar.Producer
// Interface to the component storing job information, such as which jobs are leased to a particular executor.
jobRepository database.JobRepository
// Interface to the component storing executor information, such as when we last heard from an executor.
executorRepository database.ExecutorRepository
// Like executorRepository, but used by the legacy code path.
legacyExecutorRepository database.ExecutorRepository
// Allowed priority class priorities.
allowedPriorities []int32
// Max number of job leases sent per call to LeaseJobRuns.
maxJobsPerCall uint
// Max size of Pulsar messages produced.
maxPulsarMessageSizeBytes uint
// See scheduling config.
nodeIdLabel string
// See scheduling config.
priorityClassNameOverride *string
clock clock.Clock
}
@@ -46,6 +55,7 @@ func NewExecutorApi(producer pulsar.Producer,
maxJobsPerCall uint,
nodeIdLabel string,
priorityClassNameOverride *string,
maxPulsarMessageSizeBytes uint,
) (*ExecutorApi, error) {
if len(allowedPriorities) == 0 {
return nil, errors.New("allowedPriorities cannot be empty")
@@ -60,86 +70,79 @@ func NewExecutorApi(producer pulsar.Producer,
legacyExecutorRepository: legacyExecutorRepository,
allowedPriorities: allowedPriorities,
maxJobsPerCall: maxJobsPerCall,
maxPulsarMessageSize: 1024 * 1024 * 2,
maxPulsarMessageSizeBytes: maxPulsarMessageSizeBytes,
nodeIdLabel: nodeIdLabel,
priorityClassNameOverride: priorityClassNameOverride,
clock: clock.RealClock{},
}, nil
}

// LeaseJobRuns performs the following actions:
// - Stores the request in postgres so that the scheduler can use the job + capacity information in the next scheduling round
// - Determines if any of the job runs in the request are no longer active and should be cancelled
// - Determines if any new job runs should be leased to the executor
// LeaseJobRuns reconciles the state of the executor with that of the scheduler. Specifically it:
// 1. Stores job and capacity information received from the executor to make it available to the scheduler.
// 2. Notifies the executor if any of its jobs are no longer active, e.g., due to being preempted by the scheduler.
// 3. Transfers any jobs scheduled on this executor cluster that the executor doesn't already have.
func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRunsServer) error {
ctx := stream.Context()
log := ctxlogrus.Extract(ctx)
// Receive once to get info necessary to get jobs to lease.
req, err := stream.Recv()
if err != nil {
return errors.WithStack(err)
}

log.Infof("Handling lease request for executor %s", req.ExecutorId)
ctx := stream.Context()
log := ctxlogrus.Extract(ctx)
log = log.WithField("executor", req.ExecutorId)

// store the executor state for use by the scheduler
executorState := srv.createExecutorState(ctx, req)
if err = srv.executorRepository.StoreExecutor(stream.Context(), executorState); err != nil {
executor := srv.executorFromLeaseRequest(ctx, req)
if err := srv.executorRepository.StoreExecutor(ctx, executor); err != nil {
return err
}

// store the executor state for the legacy executor to use
if err = srv.legacyExecutorRepository.StoreExecutor(stream.Context(), executorState); err != nil {
if err = srv.legacyExecutorRepository.StoreExecutor(ctx, executor); err != nil {
return err
}

requestRuns, err := extractRunIds(req)
requestRuns, err := runIdsFromLeaseRequest(req)
if err != nil {
return err
}
log.Debugf("Executor is currently aware of %d job runs", len(requestRuns))

runsToCancel, err := srv.jobRepository.FindInactiveRuns(stream.Context(), requestRuns)
runsToCancel, err := srv.jobRepository.FindInactiveRuns(ctx, requestRuns)
if err != nil {
return err
}
log.Debugf("Detected %d runs that need cancelling", len(runsToCancel))

// Fetch new leases from the db
leases, err := srv.jobRepository.FetchJobRunLeases(stream.Context(), req.ExecutorId, srv.maxJobsPerCall, requestRuns)
newRuns, err := srv.jobRepository.FetchJobRunLeases(ctx, req.ExecutorId, srv.maxJobsPerCall, requestRuns)
if err != nil {
return err
}
log.Infof(
"executor currently has %d job runs; sending %d cancellations and %d new runs",
len(requestRuns), len(runsToCancel), len(newRuns),
)

// if necessary send a list of runs to cancel
// Send any runs that should be cancelled.
if len(runsToCancel) > 0 {
err = stream.Send(&executorapi.LeaseStreamMessage{
if err := stream.Send(&executorapi.LeaseStreamMessage{
Event: &executorapi.LeaseStreamMessage_CancelRuns{
CancelRuns: &executorapi.CancelRuns{
JobRunIdsToCancel: util.Map(runsToCancel, func(x uuid.UUID) *armadaevents.Uuid {
return armadaevents.ProtoUuidFromUuid(x)
}),
},
},
})

if err != nil {
}); err != nil {
return errors.WithStack(err)
}
}

// Now send any leases
// Send any scheduled jobs the executor doesn't already have.
decompressor := compress.NewZlibDecompressor()
for _, lease := range leases {
for _, lease := range newRuns {
submitMsg := &armadaevents.SubmitJob{}
err = decompressAndMarshall(lease.SubmitMessage, decompressor, submitMsg)
if err != nil {
if err := unmarshalFromCompressedBytes(lease.SubmitMessage, decompressor, submitMsg); err != nil {
return err
}
if srv.priorityClassNameOverride != nil {
srv.setPriorityClassName(submitMsg, *srv.priorityClassNameOverride)
}
srv.addNodeSelector(submitMsg, lease.Node)
srv.addNodeIdSelector(submitMsg, lease.Node)

var groups []string
if len(lease.Groups) > 0 {
@@ -148,7 +151,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns
return err
}
}
err = stream.Send(&executorapi.LeaseStreamMessage{
err := stream.Send(&executorapi.LeaseStreamMessage{
Event: &executorapi.LeaseStreamMessage_Lease{
Lease: &executorapi.JobRunLease{
JobRunId: armadaevents.ProtoUuidFromUuid(lease.RunID),
@@ -189,11 +192,10 @@ func (srv *ExecutorApi) setPriorityClassName(job *armadaevents.SubmitJob, priori
}
}

func (srv *ExecutorApi) addNodeSelector(job *armadaevents.SubmitJob, nodeId string) {
func (srv *ExecutorApi) addNodeIdSelector(job *armadaevents.SubmitJob, nodeId string) {
if job == nil || nodeId == "" {
return
}

if job.MainObject != nil {
switch typed := job.MainObject.Object.(type) {
case *armadaevents.KubernetesMainObject_PodSpec:
@@ -207,9 +209,10 @@ func addNodeSelector(podSpec *armadaevents.PodSpecWithAvoidList, key string, val
return
}
if podSpec.PodSpec.NodeSelector == nil {
podSpec.PodSpec.NodeSelector = make(map[string]string, 1)
podSpec.PodSpec.NodeSelector = map[string]string{key: value}
} else {
podSpec.PodSpec.NodeSelector[key] = value
}
podSpec.PodSpec.NodeSelector[key] = value
}

func setPriorityClassName(podSpec *armadaevents.PodSpecWithAvoidList, priorityClassName string) {
@@ -219,19 +222,19 @@ func setPriorityClassName(podSpec *armadaevents.PodSpecWithAvoidList, priorityCl
podSpec.PodSpec.PriorityClassName = priorityClassName
}

// ReportEvents publishes all events to pulsar. The events are compacted for more efficient publishing
// ReportEvents publishes all events to Pulsar. The events are compacted for more efficient publishing.
func (srv *ExecutorApi) ReportEvents(ctx context.Context, list *executorapi.EventList) (*types.Empty, error) {
err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSize, schedulers.Pulsar)
err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSizeBytes, schedulers.Pulsar)
return &types.Empty{}, err
}

// createExecutorState extracts a schedulerobjects.Executor from the requesrt
func (srv *ExecutorApi) createExecutorState(ctx context.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor {
// executorFromLeaseRequest extracts a schedulerobjects.Executor from the request.
func (srv *ExecutorApi) executorFromLeaseRequest(ctx context.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor {
log := ctxlogrus.Extract(ctx)
nodes := make([]*schedulerobjects.Node, 0, len(req.Nodes))
now := srv.clock.Now().UTC()
for _, nodeInfo := range req.Nodes {
node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, srv.clock.Now().UTC())
if err != nil {
if node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, now); err != nil {
logging.WithStacktrace(log, err).Warnf(
"skipping node %s from executor %s", nodeInfo.GetName(), req.GetExecutorId(),
)
@@ -244,37 +247,35 @@ func (srv *ExecutorApi) createExecutorState(ctx context.Context, req *executorap
Pool: req.Pool,
Nodes: nodes,
MinimumJobSize: schedulerobjects.ResourceList{Resources: req.MinimumJobSize},
LastUpdateTime: srv.clock.Now().UTC(),
UnassignedJobRuns: util.Map(req.UnassignedJobRunIds, func(x armadaevents.Uuid) string {
return strings.ToLower(armadaevents.UuidFromProtoUuid(&x).String())
LastUpdateTime: now,
UnassignedJobRuns: util.Map(req.UnassignedJobRunIds, func(jobId armadaevents.Uuid) string {
return strings.ToLower(armadaevents.UuidFromProtoUuid(&jobId).String())
}),
}
}

// extractRunIds extracts all the job runs contained in the executor request
func extractRunIds(req *executorapi.LeaseRequest) ([]uuid.UUID, error) {
runIds := make([]uuid.UUID, 0)
// add all runids from nodes
// runIdsFromLeaseRequest returns the ids of all runs in a lease request, including any not yet assigned to a node.
func runIdsFromLeaseRequest(req *executorapi.LeaseRequest) ([]uuid.UUID, error) {
runIds := make([]uuid.UUID, 0, 256)
for _, node := range req.Nodes {
for runIdStr := range node.RunIdsByState {
runId, err := uuid.Parse(runIdStr)
if err != nil {
if runId, err := uuid.Parse(runIdStr); err != nil {
return nil, errors.WithStack(err)
} else {
runIds = append(runIds, runId)
}
runIds = append(runIds, runId)
}
}
// add all unassigned runids
for _, runId := range req.UnassignedJobRunIds {
runIds = append(runIds, armadaevents.UuidFromProtoUuid(&runId))
}
return runIds, nil
}

func decompressAndMarshall(b []byte, decompressor compress.Decompressor, msg proto.Message) error {
decompressed, err := decompressor.Decompress(b)
func unmarshalFromCompressedBytes(bytes []byte, decompressor compress.Decompressor, msg proto.Message) error {
decompressedBytes, err := decompressor.Decompress(bytes)
if err != nil {
return err
}
return proto.Unmarshal(decompressed, msg)
return proto.Unmarshal(decompressedBytes, msg)
}
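
The extraction helpers at the bottom of api.go were renamed for clarity: extractRunIds became runIdsFromLeaseRequest and decompressAndMarshall became unmarshalFromCompressedBytes. Below is a minimal standalone sketch of the run-id collection, using simplified stand-in types rather than the real executorapi messages:

```go
package main

import (
	"fmt"

	"github.com/google/uuid"
)

// Simplified stand-ins for the executorapi request types; the real
// LeaseRequest carries run ids per node plus a flat list of unassigned runs.
type node struct {
	RunIdsByState map[string]string // run id -> run state
}

type leaseRequest struct {
	Nodes               []node
	UnassignedJobRunIds []uuid.UUID
}

// runIdsFromLeaseRequest mirrors the helper in api.go: collect every run id
// the executor reported, whether or not the run is assigned to a node yet.
func runIdsFromLeaseRequest(req *leaseRequest) ([]uuid.UUID, error) {
	runIds := make([]uuid.UUID, 0, 256)
	for _, n := range req.Nodes {
		for runIdStr := range n.RunIdsByState {
			runId, err := uuid.Parse(runIdStr)
			if err != nil {
				return nil, err // the real code wraps this with errors.WithStack
			}
			runIds = append(runIds, runId)
		}
	}
	runIds = append(runIds, req.UnassignedJobRunIds...)
	return runIds, nil
}

func main() {
	req := &leaseRequest{
		Nodes:               []node{{RunIdsByState: map[string]string{uuid.NewString(): "RUNNING"}}},
		UnassignedJobRunIds: []uuid.UUID{uuid.New()},
	}
	runIds, err := runIdsFromLeaseRequest(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(runIds)) // 2
}
```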
4 changes: 3 additions & 1 deletion internal/scheduler/api_test.go
@@ -171,7 +171,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
mockLegacyExecutorRepository := schedulermocks.NewMockExecutorRepository(ctrl)
mockStream := schedulermocks.NewMockExecutorApi_LeaseJobRunsServer(ctrl)

runIds, err := extractRunIds(tc.request)
runIds, err := runIdsFromLeaseRequest(tc.request)
require.NoError(t, err)

// set up mocks
@@ -204,6 +204,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
maxJobsPerCall,
"kubernetes.io/hostname",
nil,
4*1024*1024,
)
require.NoError(t, err)
server.clock = testClock
@@ -331,6 +332,7 @@ func TestExecutorApi_Publish(t *testing.T) {
100,
"kubernetes.io/hostname",
nil,
4*1024*1024,
)

require.NoError(t, err)
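
Both tests now pass an explicit 4 MiB limit where NewExecutorApi previously hard-coded 2 MiB. The guard below is a hypothetical illustration of what such a limit enforces; the real compaction and size handling lives in pulsarutils.CompactAndPublishSequences:

```go
package main

import "fmt"

// checkMessageSize is a hypothetical guard showing what a
// maxPulsarMessageSizeBytes limit means; it is not part of the real code.
func checkMessageSize(payload []byte, maxBytes uint) error {
	if uint(len(payload)) > maxBytes {
		return fmt.Errorf("message is %d bytes, exceeding the %d byte limit", len(payload), maxBytes)
	}
	return nil
}

func main() {
	const maxPulsarMessageSizeBytes = 4 * 1024 * 1024 // matches the value used in the tests
	err := checkMessageSize(make([]byte, 5*1024*1024), maxPulsarMessageSizeBytes)
	fmt.Println(err)
}
```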
12 changes: 6 additions & 6 deletions internal/scheduler/jobdb/job.go
@@ -13,25 +13,25 @@ import (

// Job is the scheduler-internal representation of a job.
type Job struct {
// String representation of the job id
// String representation of the job id.
id string
// Name of the queue this job belongs to.
queue string
// Jobset the job belongs to
// We store this as it's needed for sending job event messages
// Jobset the job belongs to.
// We store this as it's needed for sending job event messages.
jobset string
// Per-queue priority of this job.
priority uint32
// Requested per queue priority of this job.
// This is used when syncing the postgres database with the scheduler-internal database
// This is used when syncing the postgres database with the scheduler-internal database.
requestedPriority uint32
// Logical timestamp indicating the order in which jobs are submitted.
// Jobs with identical Queue and Priority are sorted by this.
created int64
// True if the job is currently queued.
// If this is set then the job will not be considered for scheduling
// If this is set then the job will not be considered for scheduling.
queued bool
// The current version of the queued state
// The current version of the queued state.
queuedVersion int32
// Scheduling requirements of this job.
jobSchedulingInfo *schedulerobjects.JobSchedulingInfo
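
The comment cleanups in job.go spell out the ordering contract: jobs with identical queue and priority are ordered by the logical created timestamp. A minimal sketch of that tie-break (the direction of the priority comparison is illustrative, not taken from the scheduler):

```go
package main

import (
	"fmt"
	"sort"
)

// A cut-down Job carrying just the fields relevant to ordering.
type job struct {
	id       string
	priority uint32
	created  int64 // logical timestamp: order of submission
}

func main() {
	jobs := []job{
		{id: "c", priority: 1, created: 3},
		{id: "a", priority: 1, created: 1},
		{id: "b", priority: 0, created: 2},
	}
	// Jobs with identical priority are ordered by submission time, as the
	// doc comment on Job.created describes.
	sort.Slice(jobs, func(i, j int) bool {
		if jobs[i].priority != jobs[j].priority {
			return jobs[i].priority < jobs[j].priority
		}
		return jobs[i].created < jobs[j].created
	})
	for _, j := range jobs {
		fmt.Println(j.id) // b, a, c
	}
}
```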
12 changes: 5 additions & 7 deletions internal/scheduler/scheduler.go
@@ -209,8 +209,8 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade
}
events = append(events, expirationEvents...)

// Schedule jobs.
if s.clock.Now().Sub(s.previousSchedulingRoundEnd) > s.schedulePeriod {
// Schedule jobs.
overallSchedulerResult, err := s.schedulingAlgo.Schedule(ctx, txn, s.jobDb)
if err != nil {
return err
@@ -222,8 +222,6 @@
}
events = append(events, resultEvents...)
s.previousSchedulingRoundEnd = s.clock.Now()
} else {
log.Infof("skipping scheduling new jobs this cycle as a scheduling round ran less than %s ago", s.schedulePeriod)
}

// Publish to Pulsar.
@@ -264,7 +262,7 @@ func (s *Scheduler) syncState(ctx context.Context) ([]*jobdb.Job, error) {
// Try and retrieve the job from the jobDb. If it doesn't exist then create it.
job := s.jobDb.GetById(txn, dbJob.JobID)
if job == nil {
job, err = s.createSchedulerJob(&dbJob)
job, err = s.schedulerJobFromDatabaseJob(&dbJob)
if err != nil {
return nil, err
}
@@ -817,8 +815,8 @@ func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Dura
}
}

// createSchedulerJob creates a new scheduler job from a database job.
func (s *Scheduler) createSchedulerJob(dbJob *database.Job) (*jobdb.Job, error) {
// schedulerJobFromDatabaseJob creates a new scheduler job from a database job.
func (s *Scheduler) schedulerJobFromDatabaseJob(dbJob *database.Job) (*jobdb.Job, error) {
schedulingInfo := &schedulerobjects.JobSchedulingInfo{}
err := proto.Unmarshal(dbJob.SchedulingInfo, schedulingInfo)
if err != nil {
@@ -892,7 +890,7 @@ func updateSchedulerRun(run *jobdb.JobRun, dbRun *database.Run) *jobdb.JobRun {
return run
}

// updateSchedulerJob updates the scheduler job (in-place) to match the database job
// updateSchedulerJob updates the scheduler job in-place to match the database job.
func updateSchedulerJob(job *jobdb.Job, dbJob *database.Job) (*jobdb.Job, error) {
if dbJob.CancelRequested && !job.CancelRequested() {
job = job.WithCancelRequested(true)
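
In scheduler.go the "Schedule jobs." comment moves inside the throttle branch and the per-cycle skip log is dropped; the throttle itself is unchanged. A simplified sketch of the pattern, with a plain now function standing in for the scheduler's injectable clock:

```go
package main

import (
	"fmt"
	"time"
)

// scheduler is a cut-down stand-in carrying only the fields the throttle uses.
type scheduler struct {
	now                        func() time.Time
	schedulePeriod             time.Duration
	previousSchedulingRoundEnd time.Time
}

// maybeSchedule runs a scheduling round only if more than schedulePeriod has
// elapsed since the previous round ended, mirroring the check in cycle().
func (s *scheduler) maybeSchedule(runRound func()) {
	if s.now().Sub(s.previousSchedulingRoundEnd) > s.schedulePeriod {
		runRound()
		s.previousSchedulingRoundEnd = s.now()
	}
}

func main() {
	s := &scheduler{now: time.Now, schedulePeriod: 10 * time.Second}
	s.maybeSchedule(func() { fmt.Println("first round runs") })
	s.maybeSchedule(func() { fmt.Println("second round is throttled") }) // not printed
}
```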
3 changes: 2 additions & 1 deletion internal/scheduler/schedulerapp.go
@@ -115,7 +115,7 @@ func Run(config schedulerconfig.Configuration) error {
defer grpcServer.GracefulStop()
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", config.Grpc.Port))
if err != nil {
return errors.WithMessage(err, "error setting up grpc server")
return errors.WithMessage(err, "error setting up gRPC server")
}
allowedPcs := config.Scheduling.Preemption.AllowedPriorities()
executorServer, err := NewExecutorApi(
@@ -127,6 +127,7 @@ func Run(config schedulerconfig.Configuration) error {
config.Scheduling.MaximumJobsToSchedule,
config.Scheduling.Preemption.NodeIdLabel,
config.Scheduling.Preemption.PriorityClassNameOverride,
config.Pulsar.MaxAllowedMessageSize,
)
if err != nil {
return errors.WithMessage(err, "error creating executorApi")
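
schedulerapp.go now threads config.Pulsar.MaxAllowedMessageSize through to NewExecutorApi and capitalises gRPC in the error message. The snippet below shows the wrapping convention used throughout, errors.WithMessage from github.com/pkg/errors, in isolation; the invalid port is chosen only to force a failure:

```go
package main

import (
	"fmt"
	"net"

	"github.com/pkg/errors"
)

// listen wraps net.Listen the same way schedulerapp.go does: on failure the
// returned error carries added context but still unwraps to the root cause.
func listen(port int) (net.Listener, error) {
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return nil, errors.WithMessage(err, "error setting up gRPC server")
	}
	return lis, nil
}

func main() {
	if _, err := listen(-2); err != nil {
		fmt.Println(err)               // error setting up gRPC server: ...
		fmt.Println(errors.Cause(err)) // the original error is preserved
	}
}
```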
2 changes: 1 addition & 1 deletion internal/scheduleringester/dbops.go
@@ -99,7 +99,7 @@ func AppendDbOperation(ops []DbOperation, op DbOperation) []DbOperation {
break
}
}
return discardNilOps(ops) // TODO: Can be made more efficient.
return discardNilOps(ops)
}

func discardNilOps(ops []DbOperation) []DbOperation {
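
The TODO on AppendDbOperation is dropped without changing behaviour: each append still tries to merge the new operation into its predecessor and then compacts away any nil slots. A toy sketch of that merge-then-compact shape, where dbOp and tryMerge are invented stand-ins rather than the real DbOperation interface:

```go
package main

import "fmt"

// Toy stand-in for scheduleringester's DbOperation: batches of job ids keyed
// by kind; two adjacent ops of the same kind can be merged into one.
type dbOp struct {
	kind string
	ids  []int
}

func tryMerge(a, b *dbOp) *dbOp {
	if a == nil || b == nil || a.kind != b.kind {
		return nil
	}
	return &dbOp{kind: a.kind, ids: append(append([]int{}, a.ids...), b.ids...)}
}

// appendDbOperation mirrors the shape of AppendDbOperation: append the new op,
// try to merge it into its predecessor, and drop any slots nilled by merging.
func appendDbOperation(ops []*dbOp, op *dbOp) []*dbOp {
	ops = append(ops, op)
	for i := len(ops) - 1; i > 0; i-- {
		if merged := tryMerge(ops[i-1], ops[i]); merged != nil {
			ops[i-1] = merged
			ops[i] = nil
		} else {
			break
		}
	}
	return discardNilOps(ops)
}

func discardNilOps(ops []*dbOp) []*dbOp {
	out := make([]*dbOp, 0, len(ops))
	for _, op := range ops {
		if op != nil {
			out = append(out, op)
		}
	}
	return out
}

func main() {
	var ops []*dbOp
	ops = appendDbOperation(ops, &dbOp{kind: "insert", ids: []int{1}})
	ops = appendDbOperation(ops, &dbOp{kind: "insert", ids: []int{2}})
	fmt.Println(len(ops), ops[0].ids) // 1 [1 2]
}
```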