From 7cdf653d1da3f155768194f5ff4101481f3beb67 Mon Sep 17 00:00:00 2001 From: Aviral Singh Date: Mon, 11 Sep 2023 01:32:20 +0530 Subject: [PATCH 1/9] Add note about signed commits to Contributor documentation (#2960) * Add note about signed commits to Contributor documentation Signed-off-by: Aviral Singh * Add note about signed commits to Contributor documentation --------- Signed-off-by: Aviral Singh --- CONTRIBUTING.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index adec709fb56..78cfb8ca6ba 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -53,6 +53,42 @@ Note the names of the branch must follow proper docker names: >A tag name must be valid ASCII and may contain lowercase and uppercase letters, digits, underscores, periods and dashes. A tag name may not start with a period or a dash and may contain a maximum of 128 characters.
+#### Signing Off Commits
+
+To enhance the integrity of contributions to the Armada repository, we've adopted the use of the DCO (Developer Certificate of Origin) plug-in. This means that for every commit you contribute via Pull Requests, you'll need to sign off your commits to certify that you have the right to submit them under the open source license used by this project.
+
+**Every commit in your PRs must have a "Signed-off-by" line.**
+
+When committing to the repository, ensure you use the `--signoff` option with `git commit`. This will append a sign-off line at the end of the commit message to indicate that the commit has your signature.
+
+You sign off by adding the following to your commit messages:
+
+```
+Author: Your Name
+Date: Thu Feb 2 11:41:15 2018 -0800
+
+    This is my commit message
+
+    Signed-off-by: Your Name
+```
+
+Notice the `Author` and `Signed-off-by` lines match. If they don't, the PR will
+be rejected by the automated DCO check.
+
+Git has a `-s` command-line option to do this automatically:
+
+    git commit -s -m 'This is my commit message'
+
+If you forget to do this and have not yet pushed your changes to the remote
+repository, you can amend your commit with the sign-off by running
+
+    git commit --amend -s
+
+This command will modify the latest commit and add the required sign-off.
+
+For more details, check out [DCO](https://github.com/apps/dco).
+
+
 ## Chat & Discussions Sometimes, it's good to hash things out in real time.
From d616febd2e4ffd911fa9c9efbe0f7061adc4f948 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Mon, 11 Sep 2023 10:03:14 +0100 Subject: [PATCH 2/9] ArmadaContext that includes a logger (#2934) * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * compilation!
* rename package * more compilation * rename to Context * embed * compilation * compilation * fix test * remove old ctxloggers * revert design doc * revert developer doc * formatting * wip * tests * don't gen * don't gen * merged master --------- Co-authored-by: Chris Martin Co-authored-by: Albin Severinson --- cmd/armada/main.go | 5 +- cmd/eventsprinter/logic/logic.go | 8 +- cmd/executor/main.go | 4 +- cmd/lookoutv2/main.go | 14 +- cmd/scheduler/cmd/migrate_database.go | 4 +- internal/armada/repository/event.go | 9 +- internal/armada/repository/event_store.go | 9 +- internal/armada/scheduling/lease_manager.go | 4 +- internal/armada/server.go | 9 +- internal/armada/server/authorization.go | 6 +- internal/armada/server/event.go | 14 +- internal/armada/server/event_test.go | 13 +- internal/armada/server/eventsprinter.go | 7 +- internal/armada/server/lease.go | 60 +++--- internal/armada/server/lease_test.go | 26 +-- internal/armada/server/reporting.go | 24 +-- internal/armada/server/submit.go | 54 +++-- internal/armada/server/submit_from_log.go | 77 +++---- .../armada/server/submit_from_log_test.go | 8 +- internal/armada/server/submit_to_log.go | 26 ++- internal/armada/server/usage.go | 4 +- internal/armada/server/usage_test.go | 10 +- internal/armadactl/analyze.go | 4 +- internal/armadactl/kube.go | 4 +- internal/armadactl/resources.go | 4 +- internal/armadactl/watch.go | 4 +- internal/binoculars/server/binoculars.go | 5 +- internal/binoculars/service/cordon.go | 8 +- internal/binoculars/service/cordon_test.go | 12 +- internal/binoculars/service/logs.go | 6 +- internal/common/app/app.go | 7 +- .../common/armadacontext/armada_context.go | 107 ++++++++++ .../armadacontext/armada_context_test.go | 89 ++++++++ .../auth/authorization/kubernetes_test.go | 5 +- internal/common/certs/cached_certificate.go | 5 +- .../common/certs/cached_certificate_test.go | 5 +- internal/common/client.go | 6 +- internal/common/database/db_testutil.go | 6 +- internal/common/database/functions.go | 10 +- internal/common/database/migrations.go | 9 +- internal/common/database/types/types.go | 40 ++-- internal/common/database/upsert.go | 10 +- internal/common/database/upsert_test.go | 11 +- internal/common/etcdhealth/etcdhealth.go | 6 +- internal/common/etcdhealth/etcdhealth_test.go | 7 +- internal/common/eventutil/eventutil.go | 4 +- .../common/eventutil/sequence_from_message.go | 193 ------------------ .../eventutil/sequence_from_message_test.go | 73 ------- internal/common/grpc/grpc.go | 6 +- .../common/healthmonitor/healthmonitor.go | 6 +- .../healthmonitor/manualhealthmonitor.go | 5 +- .../healthmonitor/multihealthmonitor.go | 8 +- internal/common/ingest/batch.go | 5 +- internal/common/ingest/batch_test.go | 14 +- internal/common/ingest/ingestion_pipeline.go | 15 +- .../common/ingest/ingestion_pipeline_test.go | 9 +- internal/common/pgkeyvalue/pgkeyvalue.go | 14 +- internal/common/pgkeyvalue/pgkeyvalue_test.go | 6 +- internal/common/pulsarutils/async.go | 10 +- internal/common/pulsarutils/async_test.go | 14 +- internal/common/pulsarutils/eventsequence.go | 9 +- .../common/pulsarutils/eventsequence_test.go | 7 +- internal/common/startup.go | 4 +- internal/common/util/context.go | 5 +- internal/common/util/retry.go | 6 +- internal/common/util/retry_test.go | 9 +- internal/eventingester/convert/conversions.go | 5 +- .../eventingester/convert/conversions_test.go | 8 +- internal/eventingester/store/eventstore.go | 4 +- .../eventingester/store/eventstore_test.go | 4 +- internal/executor/application.go | 7 +- 
internal/executor/context/cluster_context.go | 26 +-- .../executor/context/cluster_context_test.go | 4 +- .../context/fake/sync_cluster_context.go | 4 +- internal/executor/fake/context/context.go | 4 +- internal/executor/job/job_context.go | 4 +- .../executor/job/processors/preempt_runs.go | 4 +- .../executor/job/processors/remove_runs.go | 4 +- internal/executor/reporter/event_sender.go | 5 +- .../executor/reporter/event_sender_test.go | 4 +- internal/executor/service/job_lease.go | 8 +- internal/executor/service/job_manager.go | 6 +- internal/executor/service/job_requester.go | 4 +- .../executor/service/job_requester_test.go | 4 +- internal/executor/service/lease_requester.go | 5 +- .../executor/service/lease_requester_test.go | 10 +- .../executor/service/pod_issue_handler.go | 6 +- internal/executor/util/process.go | 6 +- internal/executor/util/process_test.go | 7 +- .../pod_utilisation_kubelet_metrics.go | 4 +- internal/lookout/repository/job_pruner.go | 9 +- internal/lookout/repository/job_sets.go | 6 +- internal/lookout/repository/jobs.go | 6 +- internal/lookout/repository/queues.go | 4 +- internal/lookout/repository/sql_repository.go | 8 +- internal/lookout/repository/utils_test.go | 4 +- internal/lookout/server/lookout.go | 8 +- internal/lookout/testutil/db_testutil.go | 4 +- .../instructions/instructions.go | 10 +- .../instructions/instructions_test.go | 42 ++-- .../lookoutingester/lookoutdb/insertion.go | 48 ++--- .../lookoutdb/insertion_test.go | 134 ++++++------ .../lookoutingesterv2/benchmark/benchmark.go | 14 +- .../instructions/instructions.go | 6 +- .../instructions/instructions_test.go | 8 +- .../lookoutingesterv2/lookoutdb/insertion.go | 38 ++-- .../lookoutdb/insertion_test.go | 134 ++++++------ internal/lookoutv2/application.go | 16 +- internal/lookoutv2/gen/restapi/doc.go | 20 +- .../gen/restapi/operations/get_health.go | 4 +- .../operations/get_health_responses.go | 6 +- .../restapi/operations/get_job_run_error.go | 4 +- .../operations/get_job_run_error_responses.go | 9 +- .../gen/restapi/operations/get_job_spec.go | 4 +- .../operations/get_job_spec_responses.go | 9 +- .../gen/restapi/operations/get_jobs.go | 4 +- .../restapi/operations/get_jobs_responses.go | 9 +- .../gen/restapi/operations/group_jobs.go | 4 +- .../operations/group_jobs_responses.go | 9 +- internal/lookoutv2/pruner/pruner.go | 11 +- internal/lookoutv2/pruner/pruner_test.go | 6 +- .../lookoutv2/repository/getjobrunerror.go | 7 +- .../repository/getjobrunerror_test.go | 6 +- internal/lookoutv2/repository/getjobs.go | 12 +- internal/lookoutv2/repository/getjobs_test.go | 126 ++++++------ internal/lookoutv2/repository/getjobspec.go | 7 +- .../lookoutv2/repository/getjobspec_test.go | 6 +- internal/lookoutv2/repository/groupjobs.go | 6 +- .../lookoutv2/repository/groupjobs_test.go | 40 ++-- internal/lookoutv2/repository/util.go | 6 +- internal/pulsartest/watch.go | 6 +- internal/scheduler/api.go | 16 +- internal/scheduler/api_test.go | 11 +- internal/scheduler/database/db.go | 2 +- internal/scheduler/database/db_pruner.go | 2 +- internal/scheduler/database/db_pruner_test.go | 8 +- .../scheduler/database/executor_repository.go | 14 +- .../database/executor_repository_test.go | 6 +- internal/scheduler/database/job_repository.go | 24 +-- .../scheduler/database/job_repository_test.go | 14 +- .../database/redis_executor_repository.go | 8 +- .../redis_executor_repository_test.go | 4 +- internal/scheduler/database/util.go | 4 +- internal/scheduler/gang_scheduler.go | 10 +- 
internal/scheduler/gang_scheduler_test.go | 4 +- internal/scheduler/jobiteration.go | 13 +- internal/scheduler/jobiteration_test.go | 17 +- internal/scheduler/leader.go | 20 +- internal/scheduler/leader_client_test.go | 4 +- internal/scheduler/leader_metrics.go | 4 +- internal/scheduler/leader_metrics_test.go | 5 +- .../leader_proxying_reports_server_test.go | 7 +- internal/scheduler/leader_test.go | 5 +- internal/scheduler/metrics.go | 10 +- internal/scheduler/metrics_test.go | 8 +- internal/scheduler/mocks/mock_repositories.go | 18 +- internal/scheduler/pool_assigner.go | 6 +- internal/scheduler/pool_assigner_test.go | 7 +- .../scheduler/preempting_queue_scheduler.go | 78 +++---- .../preempting_queue_scheduler_test.go | 11 +- .../scheduler/proxying_reports_server_test.go | 8 +- internal/scheduler/publisher.go | 12 +- internal/scheduler/publisher_test.go | 10 +- internal/scheduler/queue_scheduler.go | 4 +- internal/scheduler/queue_scheduler_test.go | 6 +- internal/scheduler/reports_test.go | 6 +- internal/scheduler/scheduler.go | 81 +++----- internal/scheduler/scheduler_test.go | 30 +-- internal/scheduler/schedulerapp.go | 7 +- internal/scheduler/scheduling_algo.go | 31 ++- internal/scheduler/scheduling_algo_test.go | 5 +- internal/scheduler/simulator/simulator.go | 12 +- internal/scheduler/submitcheck.go | 6 +- internal/scheduler/submitcheck_test.go | 6 +- .../scheduler/testfixtures/testfixtures.go | 7 - internal/scheduleringester/instructions.go | 4 +- internal/scheduleringester/schedulerdb.go | 12 +- .../scheduleringester/schedulerdb_test.go | 8 +- 178 files changed, 1331 insertions(+), 1418 deletions(-) create mode 100644 internal/common/armadacontext/armada_context.go create mode 100644 internal/common/armadacontext/armada_context_test.go delete mode 100644 internal/common/eventutil/sequence_from_message.go delete mode 100644 internal/common/eventutil/sequence_from_message_test.go diff --git a/cmd/armada/main.go b/cmd/armada/main.go index 688fd78c029..a5eb8751d63 100644 --- a/cmd/armada/main.go +++ b/cmd/armada/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "net/http" _ "net/http/pprof" @@ -13,11 +12,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/pflag" "github.com/spf13/viper" - "golang.org/x/sync/errgroup" "github.com/armadaproject/armada/internal/armada" "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" gateway "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/common/logging" @@ -67,7 +66,7 @@ func main() { } // Run services within an errgroup to propagate errors between services. - g, ctx := errgroup.WithContext(context.Background()) + g, ctx := armadacontext.ErrGroup(armadacontext.Background()) // Cancel the errgroup context on SIGINT and SIGTERM, // which shuts everything down gracefully. 
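The new `armadacontext` package created by this patch (`internal/common/armadacontext/armada_context.go`, 107 added lines per the diffstat) is not reproduced in this excerpt, yet every file below is rewritten against its API: `Background`, `TODO`, `FromGrpcCtx`, `WithCancel`, `WithTimeout`, `WithDeadline`, `ErrGroup`, `WithLogFields`, and the `Log` field. The following is only a minimal sketch inferred from those call sites, not the actual file contents, which may differ in detail:

```go
package armadacontext

import (
	"context"
	"time"

	"github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"
)

// Context embeds a standard context.Context and carries a logrus logger, so it
// satisfies context.Context wherever one is expected while exposing ctx.Log
// for structured logging (illustrative sketch only).
type Context struct {
	context.Context
	Log *logrus.Entry
}

// Background wraps context.Background with a default logger.
func Background() *Context {
	return &Context{Context: context.Background(), Log: logrus.NewEntry(logrus.New())}
}

// TODO wraps context.TODO with a default logger.
func TODO() *Context {
	return &Context{Context: context.TODO(), Log: logrus.NewEntry(logrus.New())}
}

// FromGrpcCtx adapts an incoming gRPC context.Context into a *Context.
func FromGrpcCtx(ctx context.Context) *Context {
	return &Context{Context: ctx, Log: logrus.NewEntry(logrus.New())}
}

// WithLogFields returns a copy of ctx whose logger has the given fields attached.
func WithLogFields(ctx *Context, fields map[string]interface{}) *Context {
	return &Context{Context: ctx.Context, Log: ctx.Log.WithFields(fields)}
}

// WithCancel mirrors context.WithCancel while preserving the logger.
func WithCancel(parent *Context) (*Context, context.CancelFunc) {
	c, cancel := context.WithCancel(parent.Context)
	return &Context{Context: c, Log: parent.Log}, cancel
}

// WithTimeout mirrors context.WithTimeout while preserving the logger.
func WithTimeout(parent *Context, d time.Duration) (*Context, context.CancelFunc) {
	c, cancel := context.WithTimeout(parent.Context, d)
	return &Context{Context: c, Log: parent.Log}, cancel
}

// WithDeadline mirrors context.WithDeadline while preserving the logger.
func WithDeadline(parent *Context, deadline time.Time) (*Context, context.CancelFunc) {
	c, cancel := context.WithDeadline(parent.Context, deadline)
	return &Context{Context: c, Log: parent.Log}, cancel
}

// ErrGroup mirrors errgroup.WithContext, returning a group plus a derived
// *Context that is cancelled when any goroutine in the group returns an error.
func ErrGroup(parent *Context) (*errgroup.Group, *Context) {
	g, c := errgroup.WithContext(parent.Context)
	return g, &Context{Context: c, Log: parent.Log}
}
```

With a context shaped like this, call sites can drop the separate `ctxlogrus`/`logrus` plumbing and log through `ctx.Log` directly, which is the pattern the remaining file diffs in this patch follow (for example, `getJobs` in `internal/armada/server/lease.go`).
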
diff --git a/cmd/eventsprinter/logic/logic.go b/cmd/eventsprinter/logic/logic.go index 34de61b4d61..b7a9dab8ea7 100644 --- a/cmd/eventsprinter/logic/logic.go +++ b/cmd/eventsprinter/logic/logic.go @@ -1,7 +1,6 @@ package logic import ( - "context" "fmt" "time" @@ -9,6 +8,7 @@ import ( "github.com/gogo/protobuf/proto" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -18,7 +18,7 @@ func PrintEvents(url, topic, subscription string, verbose bool) error { fmt.Println("URL:", url) fmt.Println("Topic:", topic) fmt.Println("Subscription", subscription) - return withSetup(url, topic, subscription, func(ctx context.Context, producer pulsar.Producer, consumer pulsar.Consumer) error { + return withSetup(url, topic, subscription, func(ctx *armadacontext.Context, producer pulsar.Producer, consumer pulsar.Consumer) error { // Number of active jobs. numJobs := 0 @@ -199,7 +199,7 @@ func stripPodSpec(spec *v1.PodSpec) *v1.PodSpec { } // Run action with an Armada submit client and a Pulsar producer and consumer. -func withSetup(url, topic, subscription string, action func(ctx context.Context, producer pulsar.Producer, consumer pulsar.Consumer) error) error { +func withSetup(url, topic, subscription string, action func(ctx *armadacontext.Context, producer pulsar.Producer, consumer pulsar.Consumer) error) error { pulsarClient, err := pulsar.NewClient(pulsar.ClientOptions{ URL: url, }) @@ -225,5 +225,5 @@ func withSetup(url, topic, subscription string, action func(ctx context.Context, } defer consumer.Close() - return action(context.Background(), producer, consumer) + return action(armadacontext.Background(), producer, consumer) } diff --git a/cmd/executor/main.go b/cmd/executor/main.go index ed8444fbdb4..ac6374a186c 100644 --- a/cmd/executor/main.go +++ b/cmd/executor/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "net/http" "os" "os/signal" @@ -13,6 +12,7 @@ import ( "github.com/spf13/viper" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/executor" "github.com/armadaproject/armada/internal/executor/configuration" @@ -55,7 +55,7 @@ func main() { ) defer shutdownMetricServer() - shutdown, wg := executor.StartUp(context.Background(), logrus.NewEntry(logrus.New()), config) + shutdown, wg := executor.StartUp(armadacontext.Background(), logrus.NewEntry(logrus.New()), config) go func() { <-shutdownChannel shutdown() diff --git a/cmd/lookoutv2/main.go b/cmd/lookoutv2/main.go index 3ba4a865e4d..a2d5f6be90e 100644 --- a/cmd/lookoutv2/main.go +++ b/cmd/lookoutv2/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "os" "os/signal" "syscall" @@ -12,6 +11,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/lookoutv2" "github.com/armadaproject/armada/internal/lookoutv2/configuration" @@ -36,9 +36,9 @@ func init() { pflag.Parse() } -func makeContext() (context.Context, func()) { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) +func makeContext() (*armadacontext.Context, func()) { + ctx := armadacontext.Background() + ctx, cancel := 
armadacontext.WithCancel(ctx) c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) @@ -57,7 +57,7 @@ func makeContext() (context.Context, func()) { } } -func migrate(ctx context.Context, config configuration.LookoutV2Configuration) { +func migrate(ctx *armadacontext.Context, config configuration.LookoutV2Configuration) { db, err := database.OpenPgxPool(config.Postgres) if err != nil { panic(err) @@ -74,7 +74,7 @@ func migrate(ctx context.Context, config configuration.LookoutV2Configuration) { } } -func prune(ctx context.Context, config configuration.LookoutV2Configuration) { +func prune(ctx *armadacontext.Context, config configuration.LookoutV2Configuration) { db, err := database.OpenPgxConn(config.Postgres) if err != nil { panic(err) @@ -92,7 +92,7 @@ func prune(ctx context.Context, config configuration.LookoutV2Configuration) { log.Infof("expireAfter: %v, batchSize: %v, timeout: %v", config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, config.PrunerConfig.Timeout) - ctxTimeout, cancel := context.WithTimeout(ctx, config.PrunerConfig.Timeout) + ctxTimeout, cancel := armadacontext.WithTimeout(ctx, config.PrunerConfig.Timeout) defer cancel() err = pruner.PruneDb(ctxTimeout, db, config.PrunerConfig.ExpireAfter, config.PrunerConfig.BatchSize, clock.RealClock{}) if err != nil { diff --git a/cmd/scheduler/cmd/migrate_database.go b/cmd/scheduler/cmd/migrate_database.go index 1564bffb9fd..22d6dc12dc3 100644 --- a/cmd/scheduler/cmd/migrate_database.go +++ b/cmd/scheduler/cmd/migrate_database.go @@ -1,13 +1,13 @@ package cmd import ( - "context" "time" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" ) @@ -43,7 +43,7 @@ func migrateDatabase(cmd *cobra.Command, _ []string) error { return errors.WithMessagef(err, "Failed to connect to database") } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), timeout) defer cancel() return schedulerdb.Migrate(ctx, db) } diff --git a/internal/armada/repository/event.go b/internal/armada/repository/event.go index 2e05ba377c6..9df6d7a1a05 100644 --- a/internal/armada/repository/event.go +++ b/internal/armada/repository/event.go @@ -14,6 +14,7 @@ import ( "github.com/armadaproject/armada/internal/armada/repository/apimessages" "github.com/armadaproject/armada/internal/armada/repository/sequence" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/armadaevents" @@ -48,7 +49,7 @@ func NewEventRepository(db redis.UniversalClient) *RedisEventRepository { NumTestsPerEvictionRun: 10, } - decompressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple( + decompressorPool := pool.NewObjectPool(armadacontext.Background(), pool.NewPooledObjectFactorySimple( func(context.Context) (interface{}, error) { return compress.NewZlibDecompressor(), nil }), &poolConfig) @@ -134,16 +135,16 @@ func (repo *RedisEventRepository) GetLastMessageId(queue, jobSetId string) (stri func (repo *RedisEventRepository) extractEvents(msg redis.XMessage, queue, jobSetId string) ([]*api.EventMessage, error) { data := msg.Values[dataKey] bytes := []byte(data.(string)) - 
decompressor, err := repo.decompressorPool.BorrowObject(context.Background()) + decompressor, err := repo.decompressorPool.BorrowObject(armadacontext.Background()) if err != nil { return nil, errors.WithStack(err) } - defer func(decompressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(decompressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := decompressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning decompressor to pool") } - }(repo.decompressorPool, context.Background(), decompressor) + }(repo.decompressorPool, armadacontext.Background(), decompressor) decompressedData, err := decompressor.(compress.Decompressor).Decompress(bytes) if err != nil { return nil, errors.WithStack(err) diff --git a/internal/armada/repository/event_store.go b/internal/armada/repository/event_store.go index 7241cba02ef..248a405b6a4 100644 --- a/internal/armada/repository/event_store.go +++ b/internal/armada/repository/event_store.go @@ -1,10 +1,9 @@ package repository import ( - "context" - "github.com/apache/pulsar-client-go/pulsar" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/common/schedulers" @@ -12,14 +11,14 @@ import ( ) type EventStore interface { - ReportEvents(context.Context, []*api.EventMessage) error + ReportEvents(*armadacontext.Context, []*api.EventMessage) error } type TestEventStore struct { ReceivedEvents []*api.EventMessage } -func (es *TestEventStore) ReportEvents(_ context.Context, message []*api.EventMessage) error { +func (es *TestEventStore) ReportEvents(_ *armadacontext.Context, message []*api.EventMessage) error { es.ReceivedEvents = append(es.ReceivedEvents, message...) 
return nil } @@ -35,7 +34,7 @@ func NewEventStore(producer pulsar.Producer, maxAllowedMessageSize uint) *Stream } } -func (n *StreamEventStore) ReportEvents(ctx context.Context, apiEvents []*api.EventMessage) error { +func (n *StreamEventStore) ReportEvents(ctx *armadacontext.Context, apiEvents []*api.EventMessage) error { if len(apiEvents) == 0 { return nil } diff --git a/internal/armada/scheduling/lease_manager.go b/internal/armada/scheduling/lease_manager.go index 9b34786af9c..6e1c6385f9f 100644 --- a/internal/armada/scheduling/lease_manager.go +++ b/internal/armada/scheduling/lease_manager.go @@ -1,12 +1,12 @@ package scheduling import ( - "context" "time" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" ) @@ -55,7 +55,7 @@ func (l *LeaseManager) ExpireLeases() { if e != nil { log.Error(e) } else { - e := l.eventStore.ReportEvents(context.Background(), []*api.EventMessage{event}) + e := l.eventStore.ReportEvents(armadacontext.Background(), []*api.EventMessage{event}) if e != nil { log.Error(e) } diff --git a/internal/armada/server.go b/internal/armada/server.go index e60567583bc..7f77b26b0d9 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -1,7 +1,6 @@ package armada import ( - "context" "fmt" "net" "time" @@ -13,7 +12,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "github.com/armadaproject/armada/internal/armada/cache" @@ -22,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/scheduling" "github.com/armadaproject/armada/internal/armada/server" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/database" @@ -39,7 +38,7 @@ import ( "github.com/armadaproject/armada/pkg/client" ) -func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error { +func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healthChecks *health.MultiChecker) error { log.Info("Armada server starting") log.Infof("Armada priority classes: %v", config.Scheduling.Preemption.PriorityClasses) log.Infof("Default priority class: %s", config.Scheduling.Preemption.DefaultPriorityClass) @@ -51,9 +50,9 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks // Run all services within an errgroup to propagate errors between services. // Defer cancelling the parent context to ensure the errgroup is cancelled on return. - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := armadacontext.WithCancel(ctx) defer cancel() - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) // List of services to run concurrently. 
// Because we want to start services only once all input validation has been completed, diff --git a/internal/armada/server/authorization.go b/internal/armada/server/authorization.go index 1d11253d3c7..434771afcbf 100644 --- a/internal/armada/server/authorization.go +++ b/internal/armada/server/authorization.go @@ -1,10 +1,10 @@ package server import ( - "context" "fmt" "strings" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/pkg/client/queue" @@ -60,7 +60,7 @@ func MergePermissionErrors(errs ...*ErrUnauthorized) *ErrUnauthorized { // permissions required to perform some action. The error returned is of type ErrUnauthorized. // After recovering the error (using errors.As), the caller can obtain the name of the user and the // requested permission programatically via this error type. -func checkPermission(p authorization.PermissionChecker, ctx context.Context, permission permission.Permission) error { +func checkPermission(p authorization.PermissionChecker, ctx *armadacontext.Context, permission permission.Permission) error { if !p.UserHasPermission(ctx, permission) { return &ErrUnauthorized{ Principal: authorization.GetPrincipal(ctx), @@ -74,7 +74,7 @@ func checkPermission(p authorization.PermissionChecker, ctx context.Context, per func checkQueuePermission( p authorization.PermissionChecker, - ctx context.Context, + ctx *armadacontext.Context, q queue.Queue, globalPermission permission.Permission, verb queue.PermissionVerb, diff --git a/internal/armada/server/event.go b/internal/armada/server/event.go index 484f1a3a9f9..14ea0d58e18 100644 --- a/internal/armada/server/event.go +++ b/internal/armada/server/event.go @@ -13,6 +13,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/repository/sequence" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -42,7 +43,8 @@ func NewEventServer( } } -func (s *EventServer) Report(ctx context.Context, message *api.EventMessage) (*types.Empty, error) { +func (s *EventServer) Report(grpcCtx context.Context, message *api.EventMessage) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(s.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[Report] error: %s", err) } @@ -50,7 +52,8 @@ func (s *EventServer) Report(ctx context.Context, message *api.EventMessage) (*t return &types.Empty{}, s.eventStore.ReportEvents(ctx, []*api.EventMessage{message}) } -func (s *EventServer) ReportMultiple(ctx context.Context, message *api.EventList) (*types.Empty, error) { +func (s *EventServer) ReportMultiple(grpcCtx context.Context, message *api.EventList) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(s.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[ReportMultiple] error: %s", err) } @@ -116,6 +119,7 @@ func (s *EventServer) enrichPreemptedEvent(event *api.EventMessage_Preempted, jo // GetJobSetEvents streams back all events associated with a particular job set. 
func (s *EventServer) GetJobSetEvents(request *api.JobSetRequest, stream api.Event_GetJobSetEventsServer) error { + ctx := armadacontext.FromGrpcCtx(stream.Context()) q, err := s.queueRepository.GetQueue(request.Queue) var expected *repository.ErrQueueNotFound if errors.As(err, &expected) { @@ -124,7 +128,7 @@ func (s *EventServer) GetJobSetEvents(request *api.JobSetRequest, stream api.Eve return err } - err = validateUserHasWatchPermissions(stream.Context(), s.permissions, q, request.Id) + err = validateUserHasWatchPermissions(ctx, s.permissions, q, request.Id) if err != nil { return status.Errorf(codes.PermissionDenied, "[GetJobSetEvents] %s", err) } @@ -142,7 +146,7 @@ func (s *EventServer) GetJobSetEvents(request *api.JobSetRequest, stream api.Eve return s.serveEventsFromRepository(request, s.eventRepository, stream) } -func (s *EventServer) Health(ctx context.Context, cont_ *types.Empty) (*api.HealthCheckResponse, error) { +func (s *EventServer) Health(_ context.Context, _ *types.Empty) (*api.HealthCheckResponse, error) { return &api.HealthCheckResponse{Status: api.HealthCheckResponse_SERVING}, nil } @@ -222,7 +226,7 @@ func (s *EventServer) serveEventsFromRepository(request *api.JobSetRequest, even } } -func validateUserHasWatchPermissions(ctx context.Context, permsChecker authorization.PermissionChecker, q queue.Queue, jobSetId string) error { +func validateUserHasWatchPermissions(ctx *armadacontext.Context, permsChecker authorization.PermissionChecker, q queue.Queue, jobSetId string) error { err := checkPermission(permsChecker, ctx, permissions.WatchAllEvents) var globalPermErr *ErrUnauthorized if errors.As(err, &globalPermErr) { diff --git a/internal/armada/server/event_test.go b/internal/armada/server/event_test.go index a31f24965dc..18d77478f1c 100644 --- a/internal/armada/server/event_test.go +++ b/internal/armada/server/event_test.go @@ -18,6 +18,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/internal/common/compress" @@ -30,7 +31,7 @@ func TestEventServer_Health(t *testing.T) { withEventServer( t, func(s *EventServer) { - health, err := s.Health(context.Background(), &types.Empty{}) + health, err := s.Health(armadacontext.Background(), &types.Empty{}) assert.Equal(t, health.Status, api.HealthCheckResponse_SERVING) require.NoError(t, err) }, @@ -274,7 +275,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -298,7 +299,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{"watch-all-events-group"}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -322,7 +323,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) 
principal := authorization.NewStaticPrincipal("alice", []string{"watch-queue-group"}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -344,7 +345,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { assert.NoError(t, err) principal := authorization.NewStaticPrincipal("alice", []string{"watch-events-group", "watch-queue-group"}) - ctx := authorization.WithPrincipal(context.Background(), principal) + ctx := authorization.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} err = s.GetJobSetEvents(&api.JobSetRequest{ @@ -426,7 +427,7 @@ func (s *eventStreamMock) Send(m *api.EventStreamMessage) error { func (s *eventStreamMock) Context() context.Context { if s.ctx == nil { - return context.Background() + return armadacontext.Background() } return s.ctx } diff --git a/internal/armada/server/eventsprinter.go b/internal/armada/server/eventsprinter.go index 90bbca97f83..d2ba150d6e4 100644 --- a/internal/armada/server/eventsprinter.go +++ b/internal/armada/server/eventsprinter.go @@ -9,6 +9,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/pulsarutils/pulsarrequestid" @@ -29,7 +30,7 @@ type EventsPrinter struct { } // Run the service that reads from Pulsar and updates Armada until the provided context is cancelled. -func (srv *EventsPrinter) Run(ctx context.Context) error { +func (srv *EventsPrinter) Run(ctx *armadacontext.Context) error { // Get the configured logger, or the standard logger if none is provided. var log *logrus.Entry if srv.Logger != nil { @@ -74,7 +75,7 @@ func (srv *EventsPrinter) Run(ctx context.Context) error { default: // Get a message from Pulsar, which consists of a sequence of events (i.e., state transitions). 
- ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, 10*time.Second) msg, err := consumer.Receive(ctxWithTimeout) cancel() if errors.Is(err, context.DeadlineExceeded) { // expected @@ -85,7 +86,7 @@ func (srv *EventsPrinter) Run(ctx context.Context) error { break } util.RetryUntilSuccess( - context.Background(), + armadacontext.Background(), func() error { return consumer.Ack(msg) }, func(err error) { logging.WithStacktrace(log, err).Warnf("acking pulsar message failed") diff --git a/internal/armada/server/lease.go b/internal/armada/server/lease.go index 7d1d7c2abec..9a776d0e15f 100644 --- a/internal/armada/server/lease.go +++ b/internal/armada/server/lease.go @@ -10,11 +10,9 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/gogo/protobuf/types" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/hashicorp/go-multierror" pool "github.com/jolestar/go-commons-pool" "github.com/pkg/errors" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" @@ -27,6 +25,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/scheduling" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/compress" @@ -97,6 +96,7 @@ func NewAggregatedQueueServer( TimeBetweenEvictionRuns: 0, NumTestsPerEvictionRun: 10, } + decompressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple( func(context.Context) (interface{}, error) { return compress.NewZlibDecompressor(), nil @@ -128,7 +128,7 @@ func NewAggregatedQueueServer( // // This function should be used instead of the LeaseJobs function in most cases. func (q *AggregatedQueueServer) StreamingLeaseJobs(stream api.AggregatedQueue_StreamingLeaseJobsServer) error { - if err := checkPermission(q.permissions, stream.Context(), permissions.ExecuteJobs); err != nil { + if err := checkPermission(q.permissions, armadacontext.FromGrpcCtx(stream.Context()), permissions.ExecuteJobs); err != nil { return err } @@ -151,7 +151,7 @@ func (q *AggregatedQueueServer) StreamingLeaseJobs(stream api.AggregatedQueue_St } // Get jobs to be leased. - jobs, err := q.getJobs(stream.Context(), req) + jobs, err := q.getJobs(armadacontext.FromGrpcCtx(stream.Context()), req) if err != nil { return err } @@ -262,14 +262,12 @@ func (repo *SchedulerJobRepositoryAdapter) GetExistingJobsByIds(ids []string) ([ return rv, nil } -func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingLeaseRequest) ([]*api.Job, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithFields(logrus.Fields{ - "function": "getJobs", - "cluster": req.ClusterId, - "pool": req.Pool, - }) - ctx = ctxlogrus.ToContext(ctx, log) +func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.StreamingLeaseRequest) ([]*api.Job, error) { + ctx = armadacontext. + WithLogFields(ctx, map[string]interface{}{ + "cluster": req.ClusterId, + "pool": req.Pool, + }) // Get the total capacity available across all clusters. 
usageReports, err := q.usageRepository.GetClusterUsageReports() @@ -346,7 +344,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL lastSeen, ) if err != nil { - logging.WithStacktrace(log, err).Warnf( + logging.WithStacktrace(ctx.Log, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetClusterId(), ) continue @@ -474,7 +472,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL // Give Schedule() a 3 second shorter deadline than ctx to give it a chance to finish up before ctx deadline. if deadline, ok := ctx.Deadline(); ok { var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, deadline.Add(-3*time.Second)) + ctx, cancel = armadacontext.WithDeadline(ctx, deadline.Add(-3*time.Second)) defer cancel() } @@ -558,12 +556,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL "starting scheduling with total resources %s", schedulerobjects.ResourceList{Resources: totalCapacity}.CompactString(), ) - result, err := sch.Schedule( - ctxlogrus.ToContext( - ctx, - logrus.NewEntry(logrus.New()), - ), - ) + result, err := sch.Schedule(ctx) if err != nil { return nil, err } @@ -573,7 +566,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL if q.SchedulingContextRepository != nil { sctx.ClearJobSpecs() if err := q.SchedulingContextRepository.AddSchedulingContext(sctx); err != nil { - logging.WithStacktrace(log, err).Error("failed to store scheduling context") + logging.WithStacktrace(ctx.Log, err).Error("failed to store scheduling context") } } @@ -648,7 +641,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL jobIdsToDelete := util.Map(jobsToDelete, func(job *api.Job) string { return job.Id }) log.Infof("deleting preempted jobs: %v", jobIdsToDelete) if deletionResult, err := q.jobRepository.DeleteJobs(jobsToDelete); err != nil { - logging.WithStacktrace(log, err).Error("failed to delete preempted jobs from Redis") + logging.WithStacktrace(ctx.Log, err).Error("failed to delete preempted jobs from Redis") } else { deleteErrorByJobId := armadamaps.MapKeys(deletionResult, func(job *api.Job) string { return job.Id }) for jobId := range preemptedApiJobsById { @@ -711,7 +704,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL } } if err := q.usageRepository.UpdateClusterQueueResourceUsage(req.ClusterId, currentExecutorReport); err != nil { - logging.WithStacktrace(log, err).Errorf("failed to update cluster usage") + logging.WithStacktrace(ctx.Log, err).Errorf("failed to update cluster usage") } allocatedByQueueAndPriorityClassForPool = q.aggregateAllocationAcrossExecutor(reportsByExecutor, req.Pool) @@ -735,7 +728,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL } node, err := nodeDb.GetNode(nodeId) if err != nil { - logging.WithStacktrace(log, err).Warnf("failed to set node id selector on job %s: node with id %s not found", apiJob.Id, nodeId) + logging.WithStacktrace(ctx.Log, err).Warnf("failed to set node id selector on job %s: node with id %s not found", apiJob.Id, nodeId) continue } v := node.Labels[q.schedulingConfig.Preemption.NodeIdLabel] @@ -771,7 +764,7 @@ func (q *AggregatedQueueServer) getJobs(ctx context.Context, req *api.StreamingL } node, err := nodeDb.GetNode(nodeId) if err != nil { - logging.WithStacktrace(log, err).Warnf("failed to set node name on job %s: node with id %s not found", apiJob.Id, nodeId) + 
logging.WithStacktrace(ctx.Log, err).Warnf("failed to set node name on job %s: node with id %s not found", apiJob.Id, nodeId) continue } podSpec.NodeName = node.Name @@ -880,22 +873,23 @@ func (q *AggregatedQueueServer) decompressJobOwnershipGroups(jobs []*api.Job) er } func (q *AggregatedQueueServer) decompressOwnershipGroups(compressedOwnershipGroups []byte) ([]string, error) { - decompressor, err := q.decompressorPool.BorrowObject(context.Background()) + decompressor, err := q.decompressorPool.BorrowObject(armadacontext.Background()) if err != nil { return nil, fmt.Errorf("failed to borrow decompressior because %s", err) } - defer func(decompressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(decompressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := decompressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning decompressorPool to pool") } - }(q.decompressorPool, context.Background(), decompressor) + }(q.decompressorPool, armadacontext.Background(), decompressor) return compress.DecompressStringArray(compressedOwnershipGroups, decompressor.(compress.Decompressor)) } -func (q *AggregatedQueueServer) RenewLease(ctx context.Context, request *api.RenewLeaseRequest) (*api.IdList, error) { +func (q *AggregatedQueueServer) RenewLease(grpcCtx context.Context, request *api.RenewLeaseRequest) (*api.IdList, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(q.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, err.Error()) } @@ -903,7 +897,8 @@ func (q *AggregatedQueueServer) RenewLease(ctx context.Context, request *api.Ren return &api.IdList{Ids: renewed}, e } -func (q *AggregatedQueueServer) ReturnLease(ctx context.Context, request *api.ReturnLeaseRequest) (*types.Empty, error) { +func (q *AggregatedQueueServer) ReturnLease(grpcCtx context.Context, request *api.ReturnLeaseRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(q.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, err.Error()) } @@ -1002,7 +997,8 @@ func (q *AggregatedQueueServer) addAvoidNodeAffinity( return res[0].Error } -func (q *AggregatedQueueServer) ReportDone(ctx context.Context, idList *api.IdList) (*api.IdList, error) { +func (q *AggregatedQueueServer) ReportDone(grpcCtx context.Context, idList *api.IdList) (*api.IdList, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(q.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[ReportDone] error: %s", err) } @@ -1027,7 +1023,7 @@ func (q *AggregatedQueueServer) ReportDone(ctx context.Context, idList *api.IdLi return &api.IdList{Ids: cleanedIds}, returnedError } -func (q *AggregatedQueueServer) reportLeaseReturned(ctx context.Context, leaseReturnRequest *api.ReturnLeaseRequest) error { +func (q *AggregatedQueueServer) reportLeaseReturned(ctx *armadacontext.Context, leaseReturnRequest *api.ReturnLeaseRequest) error { job, err := q.getJobById(leaseReturnRequest.JobId) if err != nil { return err diff --git a/internal/armada/server/lease_test.go b/internal/armada/server/lease_test.go index 7f3f8470491..554282c546a 100644 --- a/internal/armada/server/lease_test.go +++ b/internal/armada/server/lease_test.go @@ -1,7 +1,6 @@ package server import ( - "context" "fmt" "testing" "time" @@ -10,6 +9,7 @@ import ( 
"github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -25,7 +25,7 @@ func TestAggregatedQueueServer_ReturnLeaseCallsRepositoryMethod(t *testing.T) { _, addJobsErr := mockJobRepository.AddJobs([]*api.Job{job}) assert.Nil(t, addJobsErr) - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, }) @@ -54,7 +54,7 @@ func TestAggregatedQueueServer_ReturnLeaseCallsSendsJobLeaseReturnedEvent(t *tes _, addJobsErr := mockJobRepository.AddJobs([]*api.Job{job}) assert.Nil(t, addJobsErr) - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, Reason: reason, @@ -84,7 +84,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesDeletesJob(t *tes assert.Nil(t, addJobsErr) for i := 0; i < maxRetries; i++ { - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, JobRunAttempted: true, @@ -96,7 +96,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesDeletesJob(t *tes assert.Equal(t, jobId, mockJobRepository.returnLeaseArg2) } - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, }) @@ -125,7 +125,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesSendsJobFailedEve assert.Nil(t, addJobsErr) for i := 0; i < maxRetries; i++ { - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, JobRunAttempted: true, @@ -136,7 +136,7 @@ func TestAggregatedQueueServer_ReturningLeaseMoreThanMaxRetriesSendsJobFailedEve fakeEventStore.events = []*api.EventMessage{} } - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, }) @@ -169,7 +169,7 @@ func TestAggregatedQueueServer_ReturningLease_IncrementsRetries(t *testing.T) { assert.Nil(t, addJobsErr) // Does not count towards retries if JobRunAttempted is false - _, err := aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err := aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, JobRunAttempted: false, @@ -180,7 +180,7 @@ func TestAggregatedQueueServer_ReturningLease_IncrementsRetries(t *testing.T) { assert.Equal(t, 0, numberOfRetries) // Does count towards reties if JobRunAttempted is true - _, err = aggregatedQueueClient.ReturnLease(context.TODO(), &api.ReturnLeaseRequest{ + _, err = aggregatedQueueClient.ReturnLease(armadacontext.TODO(), &api.ReturnLeaseRequest{ ClusterId: clusterId, JobId: jobId, 
JobRunAttempted: true, @@ -452,7 +452,7 @@ type fakeEventStore struct { events []*api.EventMessage } -func (es *fakeEventStore) ReportEvents(_ context.Context, message []*api.EventMessage) error { +func (es *fakeEventStore) ReportEvents(_ *armadacontext.Context, message []*api.EventMessage) error { es.events = append(es.events, message...) return nil } @@ -469,14 +469,14 @@ func (repo *fakeSchedulingInfoRepository) UpdateClusterSchedulingInfo(report *ap type fakeExecutorRepository struct{} -func (f fakeExecutorRepository) GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) { +func (f fakeExecutorRepository) GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) { return nil, nil } -func (f fakeExecutorRepository) GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) { +func (f fakeExecutorRepository) GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) { return nil, nil } -func (f fakeExecutorRepository) StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error { +func (f fakeExecutorRepository) StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { return nil } diff --git a/internal/armada/server/reporting.go b/internal/armada/server/reporting.go index d3a5eae180b..73afc3d3c17 100644 --- a/internal/armada/server/reporting.go +++ b/internal/armada/server/reporting.go @@ -1,13 +1,13 @@ package server import ( - "context" "fmt" "time" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" ) @@ -27,7 +27,7 @@ func reportQueued(repository repository.EventStore, jobs []*api.Job) error { events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportQueued] error reporting events: %w", err) } @@ -52,7 +52,7 @@ func reportDuplicateDetected(repository repository.EventStore, results []*reposi events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportDuplicateDetected] error reporting events: %w", err) } @@ -77,7 +77,7 @@ func reportSubmitted(repository repository.EventStore, jobs []*api.Job) error { events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportSubmitted] error reporting events: %w", err) } @@ -106,7 +106,7 @@ func reportJobsLeased(repository repository.EventStore, jobs []*api.Job, cluster } } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { err = fmt.Errorf("[reportJobsLeased] error reporting events: %w", err) log.Error(err) @@ -128,7 +128,7 @@ func reportJobLeaseReturned(repository repository.EventStore, job *api.Job, leas return fmt.Errorf("error wrapping event: %w", err) } - err = repository.ReportEvents(context.Background(), []*api.EventMessage{event}) + err = repository.ReportEvents(armadacontext.Background(), []*api.EventMessage{event}) if err != nil { return fmt.Errorf("error reporting lease returned event: %w", err) } @@ -154,7 +154,7 @@ func 
reportJobsCancelling(repository repository.EventStore, requestorName string events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsCancelling] error reporting events: %w", err) } @@ -180,7 +180,7 @@ func reportJobsReprioritizing(repository repository.EventStore, requestorName st events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsReprioritizing] error reporting events: %w", err) } @@ -206,7 +206,7 @@ func reportJobsReprioritized(repository repository.EventStore, requestorName str events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsReprioritized] error reporting events: %w", err) } @@ -232,7 +232,7 @@ func reportJobsUpdated(repository repository.EventStore, requestorName string, j events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsUpdated] error reporting events: %w", err) } @@ -259,7 +259,7 @@ func reportJobsCancelled(repository repository.EventStore, requestorName string, events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportJobsCancelled] error reporting events: %w", err) } @@ -293,7 +293,7 @@ func reportFailed(repository repository.EventStore, clusterId string, jobFailure events = append(events, event) } - err := repository.ReportEvents(context.Background(), events) + err := repository.ReportEvents(armadacontext.Background(), events) if err != nil { return fmt.Errorf("[reportFailed] error reporting events: %w", err) } diff --git a/internal/armada/server/submit.go b/internal/armada/server/submit.go index ca444ff3099..c129fbb1da1 100644 --- a/internal/armada/server/submit.go +++ b/internal/armada/server/submit.go @@ -20,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" servervalidation "github.com/armadaproject/armada/internal/armada/validation" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/compress" @@ -62,7 +63,7 @@ func NewSubmitServer( NumTestsPerEvictionRun: 10, } - compressorPool := pool.NewObjectPool(context.Background(), pool.NewPooledObjectFactorySimple( + compressorPool := pool.NewObjectPool(armadacontext.Background(), pool.NewPooledObjectFactorySimple( func(context.Context) (interface{}, error) { return compress.NewZlibCompressor(512) }), &poolConfig) @@ -85,7 +86,8 @@ func (server *SubmitServer) Health(ctx context.Context, _ *types.Empty) (*api.He return &api.HealthCheckResponse{Status: api.HealthCheckResponse_SERVING}, nil } -func (server *SubmitServer) GetQueueInfo(ctx context.Context, req *api.QueueInfoRequest) (*api.QueueInfo, error) { +func (server *SubmitServer) GetQueueInfo(grpcCtx context.Context, req 
*api.QueueInfoRequest) (*api.QueueInfo, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) q, err := server.queueRepository.GetQueue(req.Name) var expected *repository.ErrQueueNotFound if errors.Is(err, expected) { @@ -121,7 +123,7 @@ func (server *SubmitServer) GetQueueInfo(ctx context.Context, req *api.QueueInfo }, nil } -func (server *SubmitServer) GetQueue(ctx context.Context, req *api.QueueGetRequest) (*api.Queue, error) { +func (server *SubmitServer) GetQueue(grpcCtx context.Context, req *api.QueueGetRequest) (*api.Queue, error) { queue, err := server.queueRepository.GetQueue(req.Name) var e *repository.ErrQueueNotFound if errors.As(err, &e) { @@ -132,7 +134,8 @@ func (server *SubmitServer) GetQueue(ctx context.Context, req *api.QueueGetReque return queue.ToAPI(), nil } -func (server *SubmitServer) CreateQueue(ctx context.Context, request *api.Queue) (*types.Empty, error) { +func (server *SubmitServer) CreateQueue(grpcCtx context.Context, request *api.Queue) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := checkPermission(server.permissions, ctx, permissions.CreateQueue) var ep *ErrUnauthorized if errors.As(err, &ep) { @@ -162,9 +165,9 @@ func (server *SubmitServer) CreateQueue(ctx context.Context, request *api.Queue) return &types.Empty{}, nil } -func (server *SubmitServer) CreateQueues(ctx context.Context, request *api.QueueList) (*api.BatchQueueCreateResponse, error) { +func (server *SubmitServer) CreateQueues(grpcCtx context.Context, request *api.QueueList) (*api.BatchQueueCreateResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) var failedQueues []*api.QueueCreateResponse - // Create a queue for each element of the request body and return the failures. for _, queue := range request.Queues { _, err := server.CreateQueue(ctx, queue) @@ -181,7 +184,8 @@ func (server *SubmitServer) CreateQueues(ctx context.Context, request *api.Queue }, nil } -func (server *SubmitServer) UpdateQueue(ctx context.Context, request *api.Queue) (*types.Empty, error) { +func (server *SubmitServer) UpdateQueue(grpcCtx context.Context, request *api.Queue) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := checkPermission(server.permissions, ctx, permissions.CreateQueue) var ep *ErrUnauthorized if errors.As(err, &ep) { @@ -206,7 +210,8 @@ func (server *SubmitServer) UpdateQueue(ctx context.Context, request *api.Queue) return &types.Empty{}, nil } -func (server *SubmitServer) UpdateQueues(ctx context.Context, request *api.QueueList) (*api.BatchQueueUpdateResponse, error) { +func (server *SubmitServer) UpdateQueues(grpcCtx context.Context, request *api.QueueList) (*api.BatchQueueUpdateResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) var failedQueues []*api.QueueUpdateResponse // Create a queue for each element of the request body and return the failures. 
@@ -225,7 +230,8 @@ func (server *SubmitServer) UpdateQueues(ctx context.Context, request *api.Queue }, nil } -func (server *SubmitServer) DeleteQueue(ctx context.Context, request *api.QueueDeleteRequest) (*types.Empty, error) { +func (server *SubmitServer) DeleteQueue(grpcCtx context.Context, request *api.QueueDeleteRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := checkPermission(server.permissions, ctx, permissions.DeleteQueue) var ep *ErrUnauthorized if errors.As(err, &ep) { @@ -250,7 +256,8 @@ func (server *SubmitServer) DeleteQueue(ctx context.Context, request *api.QueueD return &types.Empty{}, nil } -func (server *SubmitServer) SubmitJobs(ctx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { +func (server *SubmitServer) SubmitJobs(grpcCtx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) principal := authorization.GetPrincipal(ctx) jobs, e := server.createJobs(req, principal.GetName(), principal.GetGroupNames()) @@ -404,7 +411,8 @@ func (server *SubmitServer) countQueuedJobs(q queue.Queue) (int64, error) { // CancelJobs cancels jobs identified by the request. // If the request contains a job ID, only the job with that ID is cancelled. // If the request contains a queue name and a job set ID, all jobs matching those are cancelled. -func (server *SubmitServer) CancelJobs(ctx context.Context, request *api.JobCancelRequest) (*api.CancellationResult, error) { +func (server *SubmitServer) CancelJobs(grpcCtx context.Context, request *api.JobCancelRequest) (*api.CancellationResult, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if request.JobId != "" { return server.cancelJobsById(ctx, request.JobId, request.Reason) } else if request.JobSetId != "" && request.Queue != "" { @@ -413,7 +421,8 @@ func (server *SubmitServer) CancelJobs(ctx context.Context, request *api.JobCanc return nil, status.Errorf(codes.InvalidArgument, "[CancelJobs] specify either job ID or both queue name and job set ID") } -func (server *SubmitServer) CancelJobSet(ctx context.Context, request *api.JobSetCancelRequest) (*types.Empty, error) { +func (server *SubmitServer) CancelJobSet(grpcCtx context.Context, request *api.JobSetCancelRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := servervalidation.ValidateJobSetFilter(request.Filter) if err != nil { return nil, err @@ -444,7 +453,7 @@ func createJobSetFilter(filter *api.JobSetFilter) *repository.JobSetFilter { } // cancels a job with a given ID -func (server *SubmitServer) cancelJobsById(ctx context.Context, jobId string, reason string) (*api.CancellationResult, error) { +func (server *SubmitServer) cancelJobsById(ctx *armadacontext.Context, jobId string, reason string) (*api.CancellationResult, error) { jobs, err := server.jobRepository.GetExistingJobsByIds([]string{jobId}) if err != nil { return nil, status.Errorf(codes.Unavailable, "[cancelJobsById] error getting job with ID %s: %s", jobId, err) @@ -466,7 +475,7 @@ func (server *SubmitServer) cancelJobsById(ctx context.Context, jobId string, re // cancels all jobs part of a particular job set and queue func (server *SubmitServer) cancelJobsByQueueAndSet( - ctx context.Context, + ctx *armadacontext.Context, queue string, jobSetId string, filter *repository.JobSetFilter, @@ -509,7 +518,7 @@ func (server *SubmitServer) cancelJobsByQueueAndSet( return &api.CancellationResult{CancelledIds: cancelledIds}, nil } -func (server *SubmitServer) 
cancelJobs(ctx context.Context, jobs []*api.Job, reason string) (*api.CancellationResult, error) { +func (server *SubmitServer) cancelJobs(ctx *armadacontext.Context, jobs []*api.Job, reason string) (*api.CancellationResult, error) { principal := authorization.GetPrincipal(ctx) err := server.checkCancelPerms(ctx, jobs) @@ -551,7 +560,7 @@ func (server *SubmitServer) cancelJobs(ctx context.Context, jobs []*api.Job, rea return &api.CancellationResult{CancelledIds: cancelledIds}, nil } -func (server *SubmitServer) checkCancelPerms(ctx context.Context, jobs []*api.Job) error { +func (server *SubmitServer) checkCancelPerms(ctx *armadacontext.Context, jobs []*api.Job) error { queueNames := make(map[string]struct{}) for _, job := range jobs { queueNames[job.Queue] = struct{}{} @@ -581,7 +590,8 @@ func (server *SubmitServer) checkCancelPerms(ctx context.Context, jobs []*api.Jo // ReprioritizeJobs updates the priority of one of more jobs. // Returns a map from job ID to any error (or nil if the call succeeded). -func (server *SubmitServer) ReprioritizeJobs(ctx context.Context, request *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { +func (server *SubmitServer) ReprioritizeJobs(grpcCtx context.Context, request *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) var jobs []*api.Job if len(request.JobIds) > 0 { existingJobs, err := server.jobRepository.GetExistingJobsByIds(request.JobIds) @@ -674,7 +684,7 @@ func (server *SubmitServer) reportReprioritizedJobEvents(reprioritizedJobs []*ap return nil } -func (server *SubmitServer) checkReprioritizePerms(ctx context.Context, jobs []*api.Job) error { +func (server *SubmitServer) checkReprioritizePerms(ctx *armadacontext.Context, jobs []*api.Job) error { queueNames := make(map[string]struct{}) for _, job := range jobs { queueNames[job.Queue] = struct{}{} @@ -702,7 +712,7 @@ func (server *SubmitServer) checkReprioritizePerms(ctx context.Context, jobs []* return nil } -func (server *SubmitServer) getQueueOrCreate(ctx context.Context, queueName string) (*queue.Queue, error) { +func (server *SubmitServer) getQueueOrCreate(ctx *armadacontext.Context, queueName string) (*queue.Queue, error) { q, e := server.queueRepository.GetQueue(queueName) if e == nil { return &q, nil @@ -753,16 +763,16 @@ func (server *SubmitServer) createJobs(request *api.JobSubmitRequest, owner stri func (server *SubmitServer) createJobsObjects(request *api.JobSubmitRequest, owner string, ownershipGroups []string, getTime func() time.Time, getUlid func() string, ) ([]*api.Job, error) { - compressor, err := server.compressorPool.BorrowObject(context.Background()) + compressor, err := server.compressorPool.BorrowObject(armadacontext.Background()) if err != nil { return nil, err } - defer func(compressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(compressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := compressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning compressor to pool") } - }(server.compressorPool, context.Background(), compressor) + }(server.compressorPool, armadacontext.Background(), compressor) compressedOwnershipGroups, err := compress.CompressStringArray(ownershipGroups, compressor.(compress.Compressor)) if err != nil { return nil, err diff --git a/internal/armada/server/submit_from_log.go b/internal/armada/server/submit_from_log.go index 13acbf9904a..90b5ece3553 100644 --- 
a/internal/armada/server/submit_from_log.go +++ b/internal/armada/server/submit_from_log.go @@ -7,19 +7,17 @@ import ( "time" "github.com/apache/pulsar-client-go/pulsar" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/hashicorp/go-multierror" pool "github.com/jolestar/go-commons-pool" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/logging" - "github.com/armadaproject/armada/internal/common/pulsarutils/pulsarrequestid" - "github.com/armadaproject/armada/internal/common/requestid" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/api" @@ -38,7 +36,7 @@ type SubmitFromLog struct { } // Run the service that reads from Pulsar and updates Armada until the provided context is cancelled. -func (srv *SubmitFromLog) Run(ctx context.Context) error { +func (srv *SubmitFromLog) Run(ctx *armadacontext.Context) error { // Get the configured logger, or the standard logger if none is provided. log := srv.getLogger() log.Info("service started") @@ -95,7 +93,7 @@ func (srv *SubmitFromLog) Run(ctx context.Context) error { default: // Get a message from Pulsar, which consists of a sequence of events (i.e., state transitions). - ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, 10*time.Second) msg, err := srv.Consumer.Receive(ctxWithTimeout) cancel() if errors.Is(err, context.DeadlineExceeded) { @@ -121,29 +119,18 @@ func (srv *SubmitFromLog) Run(ctx context.Context) error { lastPublishTime = msg.PublishTime() numReceived++ - // Incoming gRPC requests are annotated with a unique id, - // which is included with the corresponding Pulsar message. - requestId := pulsarrequestid.FromMessageOrMissing(msg) - - // Put the requestId into a message-specific context and logger, - // which are passed on to sub-functions. - messageCtx, ok := requestid.AddToIncomingContext(ctx, requestId) - if !ok { - messageCtx = ctx - } - messageLogger := log.WithFields(logrus.Fields{"messageId": msg.ID(), requestid.MetadataKey: requestId}) - ctxWithLogger := ctxlogrus.ToContext(messageCtx, messageLogger) + ctxWithLogger := armadacontext.WithLogField(ctx, "messageId", msg.ID()) // Unmarshal and validate the message. sequence, err := eventutil.UnmarshalEventSequence(ctxWithLogger, msg.Payload()) if err != nil { srv.ack(ctx, msg) - logging.WithStacktrace(messageLogger, err).Warnf("processing message failed; ignoring") + logging.WithStacktrace(ctxWithLogger.Log, err).Warnf("processing message failed; ignoring") numErrored++ break } - messageLogger.WithField("numEvents", len(sequence.Events)).Info("processing sequence") + ctxWithLogger.Log.WithField("numEvents", len(sequence.Events)).Info("processing sequence") // TODO: Improve retry logic. srv.ProcessSequence(ctxWithLogger, sequence) srv.ack(ctx, msg) @@ -155,9 +142,7 @@ func (srv *SubmitFromLog) Run(ctx context.Context) error { // For efficiency, we may process several events at a time. // To maintain ordering, we only do so for subsequences of consecutive events of equal type. 
// The returned bool indicates if the corresponding Pulsar message should be ack'd or not. -func (srv *SubmitFromLog) ProcessSequence(ctx context.Context, sequence *armadaevents.EventSequence) bool { - log := ctxlogrus.Extract(ctx) - +func (srv *SubmitFromLog) ProcessSequence(ctx *armadacontext.Context, sequence *armadaevents.EventSequence) bool { // Sub-functions should always increment the events index unless they experience a transient error. // However, if a permanent error is mis-categorised as transient, we may get stuck forever. // To avoid that issue, we return immediately if timeout time has passed @@ -170,11 +155,11 @@ func (srv *SubmitFromLog) ProcessSequence(ctx context.Context, sequence *armadae for i < len(sequence.Events) && time.Since(lastProgress) < timeout { j, err := srv.ProcessSubSequence(ctx, i, sequence) if err != nil { - logging.WithStacktrace(log, err).WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Warnf("processing subsequence failed; ignoring") + logging.WithStacktrace(ctx.Log, err).WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Warnf("processing subsequence failed; ignoring") } if j == i { - log.WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Info("made no progress") + ctx.Log.WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Info("made no progress") // We should only get here if a transient error occurs. // Sleep for a bit before retrying. @@ -200,7 +185,7 @@ func (srv *SubmitFromLog) ProcessSequence(ctx context.Context, sequence *armadae // Events are processed by calling into the embedded srv.SubmitServer. // // Not all events are handled by this processor since the legacy scheduler writes some transitions directly to the db. -func (srv *SubmitFromLog) ProcessSubSequence(ctx context.Context, i int, sequence *armadaevents.EventSequence) (j int, err error) { +func (srv *SubmitFromLog) ProcessSubSequence(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) (j int, err error) { j = i // Initially, the next event to be processed is i. if i < 0 || i >= len(sequence.Events) { err = &armadaerrors.ErrInvalidArgument{ @@ -272,7 +257,7 @@ func (srv *SubmitFromLog) ProcessSubSequence(ctx context.Context, i int, sequenc // collectJobSubmitEvents (and the corresponding functions for other types below) // return a slice of events starting at index i in the sequence with equal type. 
-func collectJobSubmitEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.SubmitJob { +func collectJobSubmitEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.SubmitJob { result := make([]*armadaevents.SubmitJob, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_SubmitJob); ok { @@ -284,7 +269,7 @@ func collectJobSubmitEvents(ctx context.Context, i int, sequence *armadaevents.E return result } -func collectCancelJobEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJob { +func collectCancelJobEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJob { result := make([]*armadaevents.CancelJob, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_CancelJob); ok { @@ -296,7 +281,7 @@ func collectCancelJobEvents(ctx context.Context, i int, sequence *armadaevents.E return result } -func collectCancelJobSetEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJobSet { +func collectCancelJobSetEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.CancelJobSet { result := make([]*armadaevents.CancelJobSet, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_CancelJobSet); ok { @@ -308,7 +293,7 @@ func collectCancelJobSetEvents(ctx context.Context, i int, sequence *armadaevent return result } -func collectReprioritiseJobEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJob { +func collectReprioritiseJobEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJob { result := make([]*armadaevents.ReprioritiseJob, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_ReprioritiseJob); ok { @@ -320,7 +305,7 @@ func collectReprioritiseJobEvents(ctx context.Context, i int, sequence *armadaev return result } -func collectReprioritiseJobSetEvents(ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJobSet { +func collectReprioritiseJobSetEvents(ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.ReprioritiseJobSet { result := make([]*armadaevents.ReprioritiseJobSet, 0) for j := i; j < len(sequence.Events); j++ { if e, ok := sequence.Events[j].Event.(*armadaevents.EventSequence_Event_ReprioritiseJobSet); ok { @@ -332,7 +317,7 @@ func collectReprioritiseJobSetEvents(ctx context.Context, i int, sequence *armad return result } -func collectEvents[T any](ctx context.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.EventSequence_Event { +func collectEvents[T any](ctx *armadacontext.Context, i int, sequence *armadaevents.EventSequence) []*armadaevents.EventSequence_Event { events := make([]*armadaevents.EventSequence_Event, 0) for j := i; j < len(sequence.Events); j++ { if _, ok := sequence.Events[j].Event.(T); ok { @@ -359,7 +344,7 @@ func (srv *SubmitFromLog) getLogger() *logrus.Entry { // Specifically, events are not processed if writing to the database results in a network-related error. // For any other error, the jobs are marked as failed and the events are considered to have been processed. 
func (srv *SubmitFromLog) SubmitJobs( - ctx context.Context, + ctx *armadacontext.Context, userId string, groups []string, queueName string, @@ -376,16 +361,16 @@ func (srv *SubmitFromLog) SubmitJobs( } log := srv.getLogger() - compressor, err := srv.SubmitServer.compressorPool.BorrowObject(context.Background()) + compressor, err := srv.SubmitServer.compressorPool.BorrowObject(armadacontext.Background()) if err != nil { return false, err } - defer func(compressorPool *pool.ObjectPool, ctx context.Context, object interface{}) { + defer func(compressorPool *pool.ObjectPool, ctx *armadacontext.Context, object interface{}) { err := compressorPool.ReturnObject(ctx, object) if err != nil { log.WithError(err).Errorf("Error returning compressor to pool") } - }(srv.SubmitServer.compressorPool, context.Background(), compressor) + }(srv.SubmitServer.compressorPool, armadacontext.Background(), compressor) compressedOwnershipGroups, err := compress.CompressStringArray(groups, compressor.(compress.Compressor)) if err != nil { @@ -455,7 +440,7 @@ type CancelJobPayload struct { } // CancelJobs cancels all jobs specified by the provided events in a single operation. -func (srv *SubmitFromLog) CancelJobs(ctx context.Context, userId string, es []*armadaevents.CancelJob) (bool, error) { +func (srv *SubmitFromLog) CancelJobs(ctx *armadacontext.Context, userId string, es []*armadaevents.CancelJob) (bool, error) { cancelJobPayloads := make([]*CancelJobPayload, len(es)) for i, e := range es { id, err := armadaevents.UlidStringFromProtoUuid(e.JobId) @@ -475,7 +460,7 @@ func (srv *SubmitFromLog) CancelJobs(ctx context.Context, userId string, es []*a // Because event sequences are specific to queue and job set, all CancelJobSet events in a sequence are equivalent, // and we only need to call CancelJobSet once. func (srv *SubmitFromLog) CancelJobSets( - ctx context.Context, + ctx *armadacontext.Context, userId string, queueName string, jobSetName string, @@ -489,7 +474,7 @@ func (srv *SubmitFromLog) CancelJobSets( return srv.CancelJobSet(ctx, userId, queueName, jobSetName, reason) } -func (srv *SubmitFromLog) CancelJobSet(ctx context.Context, userId string, queueName string, jobSetName string, reason string) (bool, error) { +func (srv *SubmitFromLog) CancelJobSet(ctx *armadacontext.Context, userId string, queueName string, jobSetName string, reason string) (bool, error) { jobIds, err := srv.SubmitServer.jobRepository.GetActiveJobIds(queueName, jobSetName) if armadaerrors.IsNetworkError(err) { return false, err @@ -505,7 +490,7 @@ func (srv *SubmitFromLog) CancelJobSet(ctx context.Context, userId string, queue return srv.BatchedCancelJobsById(ctx, userId, cancelJobPayloads) } -func (srv *SubmitFromLog) BatchedCancelJobsById(ctx context.Context, userId string, cancelJobPayloads []*CancelJobPayload) (bool, error) { +func (srv *SubmitFromLog) BatchedCancelJobsById(ctx *armadacontext.Context, userId string, cancelJobPayloads []*CancelJobPayload) (bool, error) { // Split IDs into batches and process one batch at a time. // To reduce the number of jobs stored in memory. // @@ -538,7 +523,7 @@ type CancelledJobPayload struct { } // CancelJobsById cancels all jobs with the specified ids. 
-func (srv *SubmitFromLog) CancelJobsById(ctx context.Context, userId string, cancelJobPayloads []*CancelJobPayload) ([]string, error) { +func (srv *SubmitFromLog) CancelJobsById(ctx *armadacontext.Context, userId string, cancelJobPayloads []*CancelJobPayload) ([]string, error) { jobIdReasonMap := make(map[string]string) jobIds := util.Map(cancelJobPayloads, func(payload *CancelJobPayload) string { jobIdReasonMap[payload.JobId] = payload.Reason @@ -588,7 +573,7 @@ func (srv *SubmitFromLog) CancelJobsById(ctx context.Context, userId string, can } // ReprioritizeJobs updates the priority of one of more jobs. -func (srv *SubmitFromLog) ReprioritizeJobs(ctx context.Context, userId string, es []*armadaevents.ReprioritiseJob) (bool, error) { +func (srv *SubmitFromLog) ReprioritizeJobs(ctx *armadacontext.Context, userId string, es []*armadaevents.ReprioritiseJob) (bool, error) { if len(es) == 0 { return true, nil } @@ -635,7 +620,7 @@ func (srv *SubmitFromLog) ReprioritizeJobs(ctx context.Context, userId string, e return true, nil } -func (srv *SubmitFromLog) DeleteFailedJobs(ctx context.Context, es []*armadaevents.EventSequence_Event) (bool, error) { +func (srv *SubmitFromLog) DeleteFailedJobs(ctx *armadacontext.Context, es []*armadaevents.EventSequence_Event) (bool, error) { jobIdsToDelete := make([]string, 0, len(es)) for _, event := range es { jobErrors := event.GetJobErrors() @@ -664,7 +649,7 @@ func (srv *SubmitFromLog) DeleteFailedJobs(ctx context.Context, es []*armadaeven } // UpdateJobStartTimes records the start time (in Redis) of one of more jobs. -func (srv *SubmitFromLog) UpdateJobStartTimes(ctx context.Context, es []*armadaevents.EventSequence_Event) (bool, error) { +func (srv *SubmitFromLog) UpdateJobStartTimes(ctx *armadacontext.Context, es []*armadaevents.EventSequence_Event) (bool, error) { jobStartsInfos := make([]*repository.JobStartInfo, 0, len(es)) for _, event := range es { jobRun := event.GetJobRunRunning() @@ -713,7 +698,7 @@ func (srv *SubmitFromLog) UpdateJobStartTimes(ctx context.Context, es []*armadae // Since repeating this operation is safe (setting the priority is idempotent), // the bool indicating if events were processed is set to false if any job set failed. 
func (srv *SubmitFromLog) ReprioritizeJobSets( - ctx context.Context, + ctx *armadacontext.Context, userId string, queueName string, jobSetName string, @@ -730,7 +715,7 @@ func (srv *SubmitFromLog) ReprioritizeJobSets( } func (srv *SubmitFromLog) ReprioritizeJobSet( - ctx context.Context, + ctx *armadacontext.Context, userId string, queueName string, jobSetName string, @@ -767,7 +752,7 @@ func (srv *SubmitFromLog) ReprioritizeJobSet( return true, nil } -func (srv *SubmitFromLog) ack(ctx context.Context, msg pulsar.Message) { +func (srv *SubmitFromLog) ack(ctx *armadacontext.Context, msg pulsar.Message) { util.RetryUntilSuccess( ctx, func() error { diff --git a/internal/armada/server/submit_from_log_test.go b/internal/armada/server/submit_from_log_test.go index c3479888d06..45368bfe7e2 100644 --- a/internal/armada/server/submit_from_log_test.go +++ b/internal/armada/server/submit_from_log_test.go @@ -1,13 +1,13 @@ package server import ( - ctx "context" "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest/testfixtures" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -42,7 +42,7 @@ func TestUpdateJobStartTimes(t *testing.T) { }, } - ok, err := s.UpdateJobStartTimes(ctx.Background(), events) + ok, err := s.UpdateJobStartTimes(armadacontext.Background(), events) assert.NoError(t, err) assert.True(t, ok) @@ -59,7 +59,7 @@ func TestUpdateJobStartTimes_NonExistentJob(t *testing.T) { jobRepository: jobRepo, }, } - ok, err := s.UpdateJobStartTimes(ctx.Background(), events) + ok, err := s.UpdateJobStartTimes(armadacontext.Background(), events) assert.Nil(t, err) assert.True(t, ok) @@ -75,7 +75,7 @@ func TestUpdateJobStartTimes_RedisError(t *testing.T) { jobRepository: jobRepo, }, } - ok, err := s.UpdateJobStartTimes(ctx.Background(), events) + ok, err := s.UpdateJobStartTimes(armadacontext.Background(), events) assert.Error(t, err) assert.False(t, ok) diff --git a/internal/armada/server/submit_to_log.go b/internal/armada/server/submit_to_log.go index cf4b12ceca2..aaaee8a35e2 100644 --- a/internal/armada/server/submit_to_log.go +++ b/internal/armada/server/submit_to_log.go @@ -20,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/validation" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" @@ -68,7 +69,8 @@ type PulsarSubmitServer struct { IgnoreJobSubmitChecks bool } -func (srv *PulsarSubmitServer) SubmitJobs(ctx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { +func (srv *PulsarSubmitServer) SubmitJobs(grpcCtx context.Context, req *api.JobSubmitRequest) (*api.JobSubmitResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) userId, groups, err := srv.Authorize(ctx, req.Queue, permissions.SubmitAnyJobs, queue.PermissionVerbSubmit) if err != nil { return nil, err @@ -240,7 +242,9 @@ func (srv *PulsarSubmitServer) SubmitJobs(ctx context.Context, req *api.JobSubmi return &api.JobSubmitResponse{JobResponseItems: responses}, nil } -func (srv *PulsarSubmitServer) CancelJobs(ctx context.Context, req *api.JobCancelRequest) 
(*api.CancellationResult, error) { +func (srv *PulsarSubmitServer) CancelJobs(grpcCtx context.Context, req *api.JobCancelRequest) (*api.CancellationResult, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) + // separate code path for multiple jobs if len(req.JobIds) > 0 { return srv.cancelJobsByIdsQueueJobset(ctx, req.JobIds, req.Queue, req.JobSetId, req.Reason) @@ -328,7 +332,8 @@ func (srv *PulsarSubmitServer) CancelJobs(ctx context.Context, req *api.JobCance } // Assumes all Job IDs are in the queue and job set provided -func (srv *PulsarSubmitServer) cancelJobsByIdsQueueJobset(ctx context.Context, jobIds []string, q, jobSet string, reason string) (*api.CancellationResult, error) { +func (srv *PulsarSubmitServer) cancelJobsByIdsQueueJobset(grpcCtx context.Context, jobIds []string, q, jobSet string, reason string) (*api.CancellationResult, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if q == "" { return nil, &armadaerrors.ErrInvalidArgument{ Name: "Queue", @@ -390,7 +395,8 @@ func eventSequenceForJobIds(jobIds []string, q, jobSet, userId string, groups [] return sequence, validIds } -func (srv *PulsarSubmitServer) CancelJobSet(ctx context.Context, req *api.JobSetCancelRequest) (*types.Empty, error) { +func (srv *PulsarSubmitServer) CancelJobSet(grpcCtx context.Context, req *api.JobSetCancelRequest) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if req.Queue == "" { return nil, &armadaerrors.ErrInvalidArgument{ Name: "Queue", @@ -492,7 +498,9 @@ func (srv *PulsarSubmitServer) CancelJobSet(ctx context.Context, req *api.JobSet return &types.Empty{}, err } -func (srv *PulsarSubmitServer) ReprioritizeJobs(ctx context.Context, req *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { +func (srv *PulsarSubmitServer) ReprioritizeJobs(grpcCtx context.Context, req *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) + // If either queue or jobSetId is missing, we get the job set and queue associated // with the first job id in the request. // @@ -612,7 +620,7 @@ func (srv *PulsarSubmitServer) ReprioritizeJobs(ctx context.Context, req *api.Jo // Checks that the user has either anyPerm (e.g., permissions.SubmitAnyJobs) or perm (e.g., PermissionVerbSubmit) for this queue. // Returns the userId and groups extracted from the context. func (srv *PulsarSubmitServer) Authorize( - ctx context.Context, + ctx *armadacontext.Context, queueName string, anyPerm permission.Permission, perm queue.PermissionVerb, @@ -694,7 +702,7 @@ func (srv *PulsarSubmitServer) GetQueueInfo(ctx context.Context, req *api.QueueI } // PublishToPulsar sends pulsar messages async -func (srv *PulsarSubmitServer) publishToPulsar(ctx context.Context, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { +func (srv *PulsarSubmitServer) publishToPulsar(ctx *armadacontext.Context, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { // Reduce the number of sequences to send to the minimum possible, // and then break up any sequences larger than srv.MaxAllowedMessageSize. sequences = eventutil.CompactEventSequences(sequences) @@ -714,7 +722,7 @@ func jobKey(j *api.Job) string { // getOriginalJobIds returns the mapping between jobId and originalJobId. If the job (or more specifically the clientId // on the job) has not been seen before then jobId -> jobId. 
If the job has been seen before then jobId -> originalJobId // Note that if srv.KVStore is nil then this function simply returns jobId -> jobId -func (srv *PulsarSubmitServer) getOriginalJobIds(ctx context.Context, apiJobs []*api.Job) (map[string]string, error) { +func (srv *PulsarSubmitServer) getOriginalJobIds(ctx *armadacontext.Context, apiJobs []*api.Job) (map[string]string, error) { // Default is the current id ret := make(map[string]string, len(apiJobs)) for _, apiJob := range apiJobs { @@ -753,7 +761,7 @@ func (srv *PulsarSubmitServer) getOriginalJobIds(ctx context.Context, apiJobs [] return ret, nil } -func (srv *PulsarSubmitServer) storeOriginalJobIds(ctx context.Context, apiJobs []*api.Job) error { +func (srv *PulsarSubmitServer) storeOriginalJobIds(ctx *armadacontext.Context, apiJobs []*api.Job) error { if srv.KVStore == nil { return nil } diff --git a/internal/armada/server/usage.go b/internal/armada/server/usage.go index 92fe54abd45..9c6e1e7800e 100644 --- a/internal/armada/server/usage.go +++ b/internal/armada/server/usage.go @@ -12,6 +12,7 @@ import ( "github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/armada/repository" "github.com/armadaproject/armada/internal/armada/scheduling" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -41,7 +42,8 @@ func NewUsageServer( } } -func (s *UsageServer) ReportUsage(ctx context.Context, report *api.ClusterUsageReport) (*types.Empty, error) { +func (s *UsageServer) ReportUsage(grpcCtx context.Context, report *api.ClusterUsageReport) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) if err := checkPermission(s.permissions, ctx, permissions.ExecuteJobs); err != nil { return nil, status.Errorf(codes.PermissionDenied, "[ReportUsage] error: %s", err) } diff --git a/internal/armada/server/usage_test.go b/internal/armada/server/usage_test.go index 6464b154880..8f1fa88b30b 100644 --- a/internal/armada/server/usage_test.go +++ b/internal/armada/server/usage_test.go @@ -1,7 +1,6 @@ package server import ( - "context" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/armada/repository" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client/queue" @@ -26,14 +26,14 @@ func TestUsageServer_ReportUsage(t *testing.T) { err := s.queueRepository.CreateQueue(queue.Queue{Name: "q1", PriorityFactor: 1}) assert.Nil(t, err) - _, err = s.ReportUsage(context.Background(), oneQueueReport(now, cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now, cpu, memory)) assert.Nil(t, err) priority, err := s.usageRepository.GetClusterPriority("clusterA") assert.Nil(t, err) assert.Equal(t, 10.0, priority["q1"], "Priority should be updated for the new cluster.") - _, err = s.ReportUsage(context.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) assert.Nil(t, err) priority, err = s.usageRepository.GetClusterPriority("clusterA") @@ -51,14 +51,14 @@ func TestUsageServer_ReportUsageWithDefinedScarcity(t *testing.T) { err 
:= s.queueRepository.CreateQueue(queue.Queue{Name: "q1", PriorityFactor: 1}) assert.Nil(t, err) - _, err = s.ReportUsage(context.Background(), oneQueueReport(now, cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now, cpu, memory)) assert.Nil(t, err) priority, err := s.usageRepository.GetClusterPriority("clusterA") assert.Nil(t, err) assert.Equal(t, 5.0, priority["q1"], "Priority should be updated for the new cluster.") - _, err = s.ReportUsage(context.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) + _, err = s.ReportUsage(armadacontext.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) assert.Nil(t, err) priority, err = s.usageRepository.GetClusterPriority("clusterA") diff --git a/internal/armadactl/analyze.go b/internal/armadactl/analyze.go index de9d29fb5dc..650c0861684 100644 --- a/internal/armadactl/analyze.go +++ b/internal/armadactl/analyze.go @@ -1,11 +1,11 @@ package armadactl import ( - "context" "encoding/json" "fmt" "reflect" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" "github.com/armadaproject/armada/pkg/client/domain" @@ -17,7 +17,7 @@ func (a *App) Analyze(queue string, jobSetId string) error { events := map[string][]*api.Event{} var jobState *domain.WatchContext - client.WatchJobSet(ec, queue, jobSetId, false, true, false, false, context.Background(), func(state *domain.WatchContext, e api.Event) bool { + client.WatchJobSet(ec, queue, jobSetId, false, true, false, false, armadacontext.Background(), func(state *domain.WatchContext, e api.Event) bool { events[e.GetJobId()] = append(events[e.GetJobId()], &e) jobState = state return false diff --git a/internal/armadactl/kube.go b/internal/armadactl/kube.go index ef466f7e6b8..d9b63a0399a 100644 --- a/internal/armadactl/kube.go +++ b/internal/armadactl/kube.go @@ -1,10 +1,10 @@ package armadactl import ( - "context" "fmt" "strings" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" ) @@ -14,7 +14,7 @@ import ( func (a *App) Kube(jobId string, queueName string, jobSetId string, podNumber int, args []string) error { verb := strings.Join(args, " ") return client.WithEventClient(a.Params.ApiConnectionDetails, func(c api.EventClient) error { - state := client.GetJobSetState(c, queueName, jobSetId, context.Background(), true, false, false) + state := client.GetJobSetState(c, queueName, jobSetId, armadacontext.Background(), true, false, false) jobInfo := state.GetJobInfo(jobId) if jobInfo == nil { diff --git a/internal/armadactl/resources.go b/internal/armadactl/resources.go index 4cf4faa653c..8a7f018bc0d 100644 --- a/internal/armadactl/resources.go +++ b/internal/armadactl/resources.go @@ -1,9 +1,9 @@ package armadactl import ( - "context" "fmt" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" ) @@ -11,7 +11,7 @@ import ( // Resources prints the resources used by the jobs in job set with ID jobSetId in the given queue. 
func (a *App) Resources(queueName string, jobSetId string) error { return client.WithEventClient(a.Params.ApiConnectionDetails, func(c api.EventClient) error { - state := client.GetJobSetState(c, queueName, jobSetId, context.Background(), true, false, false) + state := client.GetJobSetState(c, queueName, jobSetId, armadacontext.Background(), true, false, false) for _, job := range state.GetCurrentState() { fmt.Fprintf(a.Out, "Job ID: %v, maximum used resources: %v\n", job.Job.Id, job.MaxUsedResources) diff --git a/internal/armadactl/watch.go b/internal/armadactl/watch.go index fd0d842d5cf..872a01388c8 100644 --- a/internal/armadactl/watch.go +++ b/internal/armadactl/watch.go @@ -1,12 +1,12 @@ package armadactl import ( - "context" "encoding/json" "fmt" "reflect" "time" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/client" "github.com/armadaproject/armada/pkg/client/domain" @@ -16,7 +16,7 @@ import ( func (a *App) Watch(queue string, jobSetId string, raw bool, exitOnInactive bool, forceNewEvents bool, forceLegacyEvents bool) error { fmt.Fprintf(a.Out, "Watching job set %s\n", jobSetId) return client.WithEventClient(a.Params.ApiConnectionDetails, func(c api.EventClient) error { - client.WatchJobSet(c, queue, jobSetId, true, true, forceNewEvents, forceLegacyEvents, context.Background(), func(state *domain.WatchContext, event api.Event) bool { + client.WatchJobSet(c, queue, jobSetId, true, true, forceNewEvents, forceLegacyEvents, armadacontext.Background(), func(state *domain.WatchContext, event api.Event) bool { if raw { data, err := json.Marshal(event) if err != nil { diff --git a/internal/binoculars/server/binoculars.go b/internal/binoculars/server/binoculars.go index 4497573a04d..0a08237058f 100644 --- a/internal/binoculars/server/binoculars.go +++ b/internal/binoculars/server/binoculars.go @@ -8,6 +8,7 @@ import ( "github.com/armadaproject/armada/internal/binoculars/service" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/pkg/api/binoculars" ) @@ -27,7 +28,7 @@ func NewBinocularsServer(logService service.LogService, cordonService service.Co func (b *BinocularsServer) Logs(ctx context.Context, request *binoculars.LogRequest) (*binoculars.LogResponse, error) { principal := authorization.GetPrincipal(ctx) - logLines, err := b.logService.GetLogs(ctx, &service.LogParams{ + logLines, err := b.logService.GetLogs(armadacontext.FromGrpcCtx(ctx), &service.LogParams{ Principal: principal, Namespace: request.PodNamespace, PodName: common.PodNamePrefix + request.JobId + "-" + strconv.Itoa(int(request.PodNumber)), @@ -42,7 +43,7 @@ func (b *BinocularsServer) Logs(ctx context.Context, request *binoculars.LogRequ } func (b *BinocularsServer) Cordon(ctx context.Context, request *binoculars.CordonRequest) (*types.Empty, error) { - err := b.cordonService.CordonNode(ctx, request) + err := b.cordonService.CordonNode(armadacontext.FromGrpcCtx(ctx), request) if err != nil { return nil, err } diff --git a/internal/binoculars/service/cordon.go b/internal/binoculars/service/cordon.go index 8d850bca8ec..584da9bf4ca 100644 --- a/internal/binoculars/service/cordon.go +++ b/internal/binoculars/service/cordon.go @@ -1,7 +1,6 @@ package service import ( - "context" "encoding/json" "fmt" "strings" @@ -14,6 +13,7 @@ import ( 
"github.com/armadaproject/armada/internal/armada/permissions" "github.com/armadaproject/armada/internal/binoculars/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/internal/common/cluster" @@ -23,7 +23,7 @@ import ( const userTemplate = "" type CordonService interface { - CordonNode(ctx context.Context, request *binoculars.CordonRequest) error + CordonNode(ctx *armadacontext.Context, request *binoculars.CordonRequest) error } type KubernetesCordonService struct { @@ -44,7 +44,7 @@ func NewKubernetesCordonService( } } -func (c *KubernetesCordonService) CordonNode(ctx context.Context, request *binoculars.CordonRequest) error { +func (c *KubernetesCordonService) CordonNode(ctx *armadacontext.Context, request *binoculars.CordonRequest) error { err := checkPermission(c.permissionChecker, ctx, permissions.CordonNodes) if err != nil { return status.Errorf(codes.PermissionDenied, err.Error()) @@ -91,7 +91,7 @@ func GetPatchBytes(patchData *nodePatch) ([]byte, error) { return json.Marshal(patchData) } -func checkPermission(p authorization.PermissionChecker, ctx context.Context, permission permission.Permission) error { +func checkPermission(p authorization.PermissionChecker, ctx *armadacontext.Context, permission permission.Permission) error { if !p.UserHasPermission(ctx, permission) { return fmt.Errorf("user %s does not have permission %s", authorization.GetPrincipal(ctx).GetName(), permission) } diff --git a/internal/binoculars/service/cordon_test.go b/internal/binoculars/service/cordon_test.go index 5a1cce961b9..eadac72fd8e 100644 --- a/internal/binoculars/service/cordon_test.go +++ b/internal/binoculars/service/cordon_test.go @@ -6,6 +6,7 @@ import ( "fmt" "testing" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -19,6 +20,7 @@ import ( clientTesting "k8s.io/client-go/testing" "github.com/armadaproject/armada/internal/binoculars/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/auth/permission" "github.com/armadaproject/armada/pkg/api/binoculars" @@ -79,7 +81,7 @@ func TestCordonNode(t *testing.T) { cordonService, client := setupTest(t, cordonConfig, FakePermissionChecker{ReturnValue: true}) ctx := authorization.WithPrincipal(context.Background(), principal) - err := cordonService.CordonNode(ctx, &binoculars.CordonRequest{ + err := cordonService.CordonNode(armadacontext.New(ctx, logrus.NewEntry(logrus.New())), &binoculars.CordonRequest{ NodeName: defaultNode.Name, }) assert.Nil(t, err) @@ -96,7 +98,7 @@ func TestCordonNode(t *testing.T) { assert.Equal(t, patch, tc.expectedPatch) // Assert resulting node is in expected state - node, err := client.CoreV1().Nodes().Get(context.Background(), defaultNode.Name, metav1.GetOptions{}) + node, err := client.CoreV1().Nodes().Get(armadacontext.Background(), defaultNode.Name, metav1.GetOptions{}) assert.Nil(t, err) assert.Equal(t, node.Spec.Unschedulable, true) assert.Equal(t, node.Labels, tc.expectedLabels) @@ -107,7 +109,7 @@ func TestCordonNode(t *testing.T) { func TestCordonNode_InvalidNodeName(t *testing.T) { cordonService, _ := setupTest(t, defaultCordonConfig, FakePermissionChecker{ReturnValue: true}) - err := 
cordonService.CordonNode(context.Background(), &binoculars.CordonRequest{ + err := cordonService.CordonNode(armadacontext.Background(), &binoculars.CordonRequest{ NodeName: "non-existent-node", }) @@ -117,7 +119,7 @@ func TestCordonNode_InvalidNodeName(t *testing.T) { func TestCordonNode_Unauthenticated(t *testing.T) { cordonService, _ := setupTest(t, defaultCordonConfig, FakePermissionChecker{ReturnValue: false}) - err := cordonService.CordonNode(context.Background(), &binoculars.CordonRequest{ + err := cordonService.CordonNode(armadacontext.Background(), &binoculars.CordonRequest{ NodeName: defaultNode.Name, }) @@ -131,7 +133,7 @@ func setupTest(t *testing.T, config configuration.CordonConfiguration, permissio client := fake.NewSimpleClientset() clientProvider := &FakeClientProvider{FakeClient: client} - _, err := client.CoreV1().Nodes().Create(context.Background(), defaultNode, metav1.CreateOptions{}) + _, err := client.CoreV1().Nodes().Create(armadacontext.Background(), defaultNode, metav1.CreateOptions{}) require.NoError(t, err) client.Fake.ClearActions() diff --git a/internal/binoculars/service/logs.go b/internal/binoculars/service/logs.go index 49801758292..ac72215f67e 100644 --- a/internal/binoculars/service/logs.go +++ b/internal/binoculars/service/logs.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "strings" "time" @@ -10,13 +9,14 @@ import ( v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/cluster" "github.com/armadaproject/armada/pkg/api/binoculars" ) type LogService interface { - GetLogs(ctx context.Context, params *LogParams) ([]*binoculars.LogLine, error) + GetLogs(ctx *armadacontext.Context, params *LogParams) ([]*binoculars.LogLine, error) } type LogParams struct { @@ -37,7 +37,7 @@ func NewKubernetesLogService(clientProvider cluster.KubernetesClientProvider) *K return &KubernetesLogService{clientProvider: clientProvider} } -func (l *KubernetesLogService) GetLogs(ctx context.Context, params *LogParams) ([]*binoculars.LogLine, error) { +func (l *KubernetesLogService) GetLogs(ctx *armadacontext.Context, params *LogParams) ([]*binoculars.LogLine, error) { client, err := l.clientProvider.ClientForUser(params.Principal.GetName(), params.Principal.GetGroupNames()) if err != nil { return nil, err diff --git a/internal/common/app/app.go b/internal/common/app/app.go index bd35f7a5a8f..25ce1e828b0 100644 --- a/internal/common/app/app.go +++ b/internal/common/app/app.go @@ -1,15 +1,16 @@ package app import ( - "context" "os" "os/signal" "syscall" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // CreateContextWithShutdown returns a context that will report done when a SIGTERM is received -func CreateContextWithShutdown() context.Context { - ctx, cancel := context.WithCancel(context.Background()) +func CreateContextWithShutdown() *armadacontext.Context { + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go func() { diff --git a/internal/common/armadacontext/armada_context.go b/internal/common/armadacontext/armada_context.go new file mode 100644 index 00000000000..a6985ee5df7 --- /dev/null +++ b/internal/common/armadacontext/armada_context.go @@ -0,0 +1,107 @@ +package armadacontext + +import ( + "context" + "time" + + 
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" + "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +// Context is an extension of Go's context which also includes a logger. This allows us to pass round a contextual logger +// while retaining type-safety +type Context struct { + context.Context + Log *logrus.Entry +} + +// Background creates an empty context with a default logger. It is analogous to context.Background() +func Background() *Context { + return &Context{ + Context: context.Background(), + Log: logrus.NewEntry(logrus.New()), + } +} + +// TODO creates an empty context with a default logger. It is analogous to context.TODO() +func TODO() *Context { + return &Context{ + Context: context.TODO(), + Log: logrus.NewEntry(logrus.New()), + } +} + +// FromGrpcCtx creates a context where the logger is extracted via ctxlogrus's Extract() method. +// Note that this will result in a no-op logger if a logger hasn't already been inserted into the context via ctxlogrus +func FromGrpcCtx(ctx context.Context) *Context { + log := ctxlogrus.Extract(ctx) + return New(ctx, log) +} + +// New returns an armada context that encapsulates both a go context and a logger +func New(ctx context.Context, log *logrus.Entry) *Context { + return &Context{ + Context: ctx, + Log: log, + } +} + +// WithCancel returns a copy of parent with a new Done channel. It is analogous to context.WithCancel() +func WithCancel(parent *Context) (*Context, context.CancelFunc) { + c, cancel := context.WithCancel(parent.Context) + return &Context{ + Context: c, + Log: parent.Log, + }, cancel +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted to be no later than d. +// It is analogous to context.WithDeadline() +func WithDeadline(parent *Context, d time.Time) (*Context, context.CancelFunc) { + c, cancel := context.WithDeadline(parent.Context, d) + return &Context{ + Context: c, + Log: parent.Log, + }, cancel +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). It is analogous to context.WithTimeout() +func WithTimeout(parent *Context, timeout time.Duration) (*Context, context.CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithLogField returns a copy of parent with the supplied key-value added to the logger +func WithLogField(parent *Context, key string, val interface{}) *Context { + return &Context{ + Context: parent.Context, + Log: parent.Log.WithField(key, val), + } +} + +// WithLogFields returns a copy of parent with the supplied key-values added to the logger +func WithLogFields(parent *Context, fields logrus.Fields) *Context { + return &Context{ + Context: parent.Context, + Log: parent.Log.WithFields(fields), + } +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. It is analogous to context.WithValue() +func WithValue(parent *Context, key, val any) *Context { + return &Context{ + Context: context.WithValue(parent, key, val), + Log: parent.Log, + } +} + +// ErrGroup returns a new Error Group and an associated Context derived from ctx. 
+// It is analogous to errgroup.WithContext(ctx) +func ErrGroup(ctx *Context) (*errgroup.Group, *Context) { + group, goctx := errgroup.WithContext(ctx) + return group, &Context{ + Context: goctx, + Log: ctx.Log, + } +} diff --git a/internal/common/armadacontext/armada_context_test.go b/internal/common/armadacontext/armada_context_test.go new file mode 100644 index 00000000000..a98d7b611df --- /dev/null +++ b/internal/common/armadacontext/armada_context_test.go @@ -0,0 +1,89 @@ +package armadacontext + +import ( + "context" + "testing" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +var defaultLogger = logrus.WithField("foo", "bar") + +func TestNew(t *testing.T) { + ctx := New(context.Background(), defaultLogger) + require.Equal(t, defaultLogger, ctx.Log) + require.Equal(t, context.Background(), ctx.Context) +} + +func TestFromGrpcContext(t *testing.T) { + grpcCtx := ctxlogrus.ToContext(context.Background(), defaultLogger) + ctx := FromGrpcCtx(grpcCtx) + require.Equal(t, grpcCtx, ctx.Context) + require.Equal(t, defaultLogger, ctx.Log) +} + +func TestBackground(t *testing.T) { + ctx := Background() + require.Equal(t, ctx.Context, context.Background()) +} + +func TestTODO(t *testing.T) { + ctx := TODO() + require.Equal(t, ctx.Context, context.TODO()) +} + +func TestWithLogField(t *testing.T) { + ctx := WithLogField(Background(), "fish", "chips") + require.Equal(t, context.Background(), ctx.Context) + require.Equal(t, logrus.Fields{"fish": "chips"}, ctx.Log.Data) +} + +func TestWithLogFields(t *testing.T) { + ctx := WithLogFields(Background(), logrus.Fields{"fish": "chips", "salt": "pepper"}) + require.Equal(t, context.Background(), ctx.Context) + require.Equal(t, logrus.Fields{"fish": "chips", "salt": "pepper"}, ctx.Log.Data) +} + +func TestWithTimeout(t *testing.T) { + ctx, _ := WithTimeout(Background(), 100*time.Millisecond) + testDeadline(t, ctx) +} + +func TestWithDeadline(t *testing.T) { + ctx, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + testDeadline(t, ctx) +} + +func TestWithValue(t *testing.T) { + ctx := WithValue(Background(), "foo", "bar") + require.Equal(t, "bar", ctx.Value("foo")) +} + +func testDeadline(t *testing.T, c *Context) { + t.Helper() + d := quiescent(t) + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + t.Fatalf("context not timed out after %v", d) + case <-c.Done(): + } + if e := c.Err(); e != context.DeadlineExceeded { + t.Errorf("c.Err() == %v; want %v", e, context.DeadlineExceeded) + } +} + +func quiescent(t *testing.T) time.Duration { + deadline, ok := t.Deadline() + if !ok { + return 5 * time.Second + } + + const arbitraryCleanupMargin = 1 * time.Second + return time.Until(deadline) - arbitraryCleanupMargin +} diff --git a/internal/common/auth/authorization/kubernetes_test.go b/internal/common/auth/authorization/kubernetes_test.go index 9493c71f80a..eef827f9add 100644 --- a/internal/common/auth/authorization/kubernetes_test.go +++ b/internal/common/auth/authorization/kubernetes_test.go @@ -10,11 +10,10 @@ import ( "time" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" - authv1 "k8s.io/api/authentication/v1" - "k8s.io/apimachinery/pkg/util/clock" - "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" + authv1 "k8s.io/api/authentication/v1" + "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/common/auth/configuration" ) diff --git 
a/internal/common/certs/cached_certificate.go b/internal/common/certs/cached_certificate.go index 2588d0f5b50..72b7f6ea250 100644 --- a/internal/common/certs/cached_certificate.go +++ b/internal/common/certs/cached_certificate.go @@ -1,13 +1,14 @@ package certs import ( - "context" "crypto/tls" "os" "sync" "time" log "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) type CachedCertificateService struct { @@ -52,7 +53,7 @@ func (c *CachedCertificateService) updateCertificate(certificate *tls.Certificat c.certificate = certificate } -func (c *CachedCertificateService) Run(ctx context.Context) { +func (c *CachedCertificateService) Run(ctx *armadacontext.Context) { ticker := time.NewTicker(c.refreshInterval) for { select { diff --git a/internal/common/certs/cached_certificate_test.go b/internal/common/certs/cached_certificate_test.go index 7687c80fd63..4edd3efd376 100644 --- a/internal/common/certs/cached_certificate_test.go +++ b/internal/common/certs/cached_certificate_test.go @@ -2,7 +2,6 @@ package certs import ( "bytes" - "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -16,6 +15,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const ( @@ -96,7 +97,7 @@ func TestCachedCertificateService_ReloadsCertPeriodically_WhenUsingRun(t *testin assert.Equal(t, cert, cachedCertService.GetCertificate()) go func() { - cachedCertService.Run(context.Background()) + cachedCertService.Run(armadacontext.Background()) }() newCert, certData, keyData := createCerts(t) diff --git a/internal/common/client.go b/internal/common/client.go index 0b44c374d0b..afc5bb5c597 100644 --- a/internal/common/client.go +++ b/internal/common/client.go @@ -3,8 +3,10 @@ package common import ( "context" "time" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func ContextWithDefaultTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), 10*time.Second) +func ContextWithDefaultTimeout() (*armadacontext.Context, context.CancelFunc) { + return armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) } diff --git a/internal/common/database/db_testutil.go b/internal/common/database/db_testutil.go index a36affdef73..416b348d7d8 100644 --- a/internal/common/database/db_testutil.go +++ b/internal/common/database/db_testutil.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "github.com/jackc/pgx/v5" @@ -10,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" ) @@ -17,7 +17,7 @@ import ( // migrations: perform the list of migrations before entering the action callback // action: callback for client code func WithTestDb(migrations []Migration, action func(db *pgxpool.Pool) error) error { - ctx := context.Background() + ctx := armadacontext.Background() // Connect and create a dedicated database for the test dbName := "test_" + util.NewULID() @@ -67,7 +67,7 @@ func WithTestDb(migrations []Migration, action func(db *pgxpool.Pool) error) err // config: PostgresConfig to specify connection details to database // action: callback for client code func WithTestDbCustom(migrations []Migration, config configuration.PostgresConfig, action func(db *pgxpool.Pool) error) error { - ctx := context.Background() + ctx := 
armadacontext.Background() testDbPool, err := OpenPgxPool(config) if err != nil { diff --git a/internal/common/database/functions.go b/internal/common/database/functions.go index 5446f7cd0e1..17f3334efab 100644 --- a/internal/common/database/functions.go +++ b/internal/common/database/functions.go @@ -1,7 +1,6 @@ package database import ( - "context" "database/sql" "fmt" "strings" @@ -13,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" ) func CreateConnectionString(values map[string]string) string { @@ -26,20 +26,20 @@ func CreateConnectionString(values map[string]string) string { } func OpenPgxConn(config configuration.PostgresConfig) (*pgx.Conn, error) { - db, err := pgx.Connect(context.Background(), CreateConnectionString(config.Connection)) + db, err := pgx.Connect(armadacontext.Background(), CreateConnectionString(config.Connection)) if err != nil { return nil, err } - err = db.Ping(context.Background()) + err = db.Ping(armadacontext.Background()) return db, err } func OpenPgxPool(config configuration.PostgresConfig) (*pgxpool.Pool, error) { - db, err := pgxpool.New(context.Background(), CreateConnectionString(config.Connection)) + db, err := pgxpool.New(armadacontext.Background(), CreateConnectionString(config.Connection)) if err != nil { return nil, err } - err = db.Ping(context.Background()) + err = db.Ping(armadacontext.Background()) return db, err } diff --git a/internal/common/database/migrations.go b/internal/common/database/migrations.go index 164c75b313d..b515c94f7fb 100644 --- a/internal/common/database/migrations.go +++ b/internal/common/database/migrations.go @@ -2,7 +2,6 @@ package database import ( "bytes" - "context" "io/fs" "path" "sort" @@ -11,6 +10,8 @@ import ( stakikfs "github.com/rakyll/statik/fs" log "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // Migration represents a single, versioned database migration script @@ -28,7 +29,7 @@ func NewMigration(id int, name string, sql string) Migration { } } -func UpdateDatabase(ctx context.Context, db Querier, migrations []Migration) error { +func UpdateDatabase(ctx *armadacontext.Context, db Querier, migrations []Migration) error { log.Info("Updating postgres...") version, err := readVersion(ctx, db) if err != nil { @@ -55,7 +56,7 @@ func UpdateDatabase(ctx context.Context, db Querier, migrations []Migration) err return nil } -func readVersion(ctx context.Context, db Querier) (int, error) { +func readVersion(ctx *armadacontext.Context, db Querier) (int, error) { _, err := db.Exec(ctx, `CREATE SEQUENCE IF NOT EXISTS database_version START WITH 0 MINVALUE 0;`) if err != nil { @@ -75,7 +76,7 @@ func readVersion(ctx context.Context, db Querier) (int, error) { return version, err } -func setVersion(ctx context.Context, db Querier, version int) error { +func setVersion(ctx *armadacontext.Context, db Querier, version int) error { _, err := db.Exec(ctx, `SELECT setval('database_version', $1)`, version) return err } diff --git a/internal/common/database/types/types.go b/internal/common/database/types/types.go index eb4f8d426be..2171d10bad1 100644 --- a/internal/common/database/types/types.go +++ b/internal/common/database/types/types.go @@ -1,10 +1,10 @@ package types import ( - "context" "database/sql" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" ) type DatabaseConnection 
interface { @@ -16,25 +16,25 @@ type DatabaseConnection interface { // executing queries, and starting transactions. type DatabaseConn interface { // Close closes the database connection. It returns any error encountered during the closing operation. - Close(context.Context) error + Close(*armadacontext.Context) error // Ping pings the database to check the connection. It returns any error encountered during the ping operation. - Ping(context.Context) error + Ping(*armadacontext.Context) error // Exec executes a query that doesn't return rows. It returns any error encountered. - Exec(context.Context, string, ...any) (any, error) + Exec(*armadacontext.Context, string, ...any) (any, error) // Query executes a query that returns multiple rows. It returns a DatabaseRows interface that allows you to iterate over the result set, and any error encountered. - Query(context.Context, string, ...any) (DatabaseRows, error) + Query(*armadacontext.Context, string, ...any) (DatabaseRows, error) // QueryRow executes a query that returns one row. It returns a DatabaseRow interface representing the result row, and any error encountered. - QueryRow(context.Context, string, ...any) DatabaseRow + QueryRow(*armadacontext.Context, string, ...any) DatabaseRow // BeginTx starts a transcation with the given DatabaseTxOptions, or returns an error if any occurred. - BeginTx(context.Context, DatabaseTxOptions) (DatabaseTx, error) + BeginTx(*armadacontext.Context, DatabaseTxOptions) (DatabaseTx, error) // BeginTxFunc starts a transaction and executes the given function within the transaction. It the function runs successfully, BeginTxFunc commits the transaction, otherwise it rolls back and return an errorr. - BeginTxFunc(context.Context, DatabaseTxOptions, func(DatabaseTx) error) error + BeginTxFunc(*armadacontext.Context, DatabaseTxOptions, func(DatabaseTx) error) error } type DatabaseTxOptions struct { @@ -47,52 +47,52 @@ type DatabaseTxOptions struct { // managing transactions, and performing bulk insertions. type DatabaseTx interface { // Exec executes a query that doesn't return rows. It returns any error encountered. - Exec(context.Context, string, ...any) (any, error) + Exec(*armadacontext.Context, string, ...any) (any, error) // Query executes a query that returns multiple rows. // It returns a DatabaseRows interface that allows you to iterate over the result set, and any error encountered. - Query(context.Context, string, ...any) (DatabaseRows, error) + Query(*armadacontext.Context, string, ...any) (DatabaseRows, error) // QueryRow executes a query that returns one row. // It returns a DatabaseRow interface representing the result row, and any error encountered. - QueryRow(context.Context, string, ...any) DatabaseRow + QueryRow(*armadacontext.Context, string, ...any) DatabaseRow // CopyFrom performs a bulk insertion of data into a specified table. // It accepts the table name, column names, and a slice of rows representing the data to be inserted. It returns the number of rows inserted and any error encountered. - CopyFrom(ctx context.Context, tableName string, columnNames []string, rows [][]any) (int64, error) + CopyFrom(ctx *armadacontext.Context, tableName string, columnNames []string, rows [][]any) (int64, error) // Commit commits the transaction. It returns any error encountered during the commit operation. - Commit(context.Context) error + Commit(*armadacontext.Context) error // Rollback rolls back the transaction. It returns any error encountered during the rollback operation. 
- Rollback(context.Context) error + Rollback(*armadacontext.Context) error } // DatabasePool represents a database connection pool interface that provides methods for acquiring and managing database connections. type DatabasePool interface { // Acquire acquires a database connection from the pool. // It takes a context and returns a DatabaseConn representing the acquired connection and any encountered error. - Acquire(context.Context) (DatabaseConn, error) + Acquire(*armadacontext.Context) (DatabaseConn, error) // Ping pings the database to check the connection. It returns any error encountered during the ping operation. - Ping(context.Context) error + Ping(*armadacontext.Context) error // Close closes the database connection. It returns any error encountered during the closing operation. Close() // Exec executes a query that doesn't return rows. It returns any error encountered. - Exec(context.Context, string, ...any) (any, error) + Exec(*armadacontext.Context, string, ...any) (any, error) // Query executes a query that returns multiple rows. // It returns a DatabaseRows interface that allows you to iterate over the result set, and any error encountered. - Query(context.Context, string, ...any) (DatabaseRows, error) + Query(*armadacontext.Context, string, ...any) (DatabaseRows, error) // BeginTx starts a transcation with the given DatabaseTxOptions, or returns an error if any occurred. - BeginTx(context.Context, DatabaseTxOptions) (DatabaseTx, error) + BeginTx(*armadacontext.Context, DatabaseTxOptions) (DatabaseTx, error) // BeginTxFunc starts a transaction and executes the given function within the transaction. // It the function runs successfully, BeginTxFunc commits the transaction, otherwise it rolls back and return an error. - BeginTxFunc(context.Context, DatabaseTxOptions, func(DatabaseTx) error) error + BeginTxFunc(*armadacontext.Context, DatabaseTxOptions, func(DatabaseTx) error) error } // DatabaseRow represents a single row in a result set. diff --git a/internal/common/database/upsert.go b/internal/common/database/upsert.go index 23f27164f9b..5df05c67918 100644 --- a/internal/common/database/upsert.go +++ b/internal/common/database/upsert.go @@ -1,19 +1,19 @@ package database import ( - "context" "fmt" "reflect" "strings" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func UpsertWithTransaction[T any](ctx context.Context, db *pgxpool.Pool, tableName string, records []T) error { +func UpsertWithTransaction[T any](ctx *armadacontext.Context, db *pgxpool.Pool, tableName string, records []T) error { if len(records) == 0 { return nil } @@ -50,7 +50,7 @@ func UpsertWithTransaction[T any](ctx context.Context, db *pgxpool.Pool, tableNa // // ) // I.e., it should omit everything before and after the "(" and ")", respectively. 
-func Upsert[T any](ctx context.Context, tx pgx.Tx, tableName string, records []T) error { +func Upsert[T any](ctx *armadacontext.Context, tx pgx.Tx, tableName string, records []T) error { if len(records) < 1 { return nil } diff --git a/internal/common/database/upsert_test.go b/internal/common/database/upsert_test.go index b1329921c1e..638d15ac494 100644 --- a/internal/common/database/upsert_test.go +++ b/internal/common/database/upsert_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "testing" "time" @@ -9,6 +8,8 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // Used for tests. @@ -55,7 +56,7 @@ func TestNamesValuesFromRecordPointer(t *testing.T) { } func TestUpsertWithTransaction(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Hour) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Hour) defer cancel() err := withDb(func(db *pgxpool.Pool) error { // Insert rows, read them back, and compare. @@ -90,7 +91,7 @@ func TestUpsertWithTransaction(t *testing.T) { } func TestConcurrency(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := withDb(func(db *pgxpool.Pool) error { // Each thread inserts non-overlapping rows, reads them back, and compares. @@ -125,7 +126,7 @@ func TestConcurrency(t *testing.T) { } func TestAutoIncrement(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := withDb(func(db *pgxpool.Pool) error { // Insert two rows. These should automatically get auto-incrementing serial numbers. 
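The upsert helpers above now take a `*armadacontext.Context` rather than a plain `context.Context`, and the tests show the calling pattern: build a bounded context with `armadacontext.WithTimeout` and pass it straight through. A minimal caller-side sketch of that pattern follows; the `exampleRecord` type, its field-to-column mapping, and the `example_table` name are illustrative assumptions, not part of this patch.

```
package example

import (
	"time"

	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/armadaproject/armada/internal/common/armadacontext"
	"github.com/armadaproject/armada/internal/common/database"
)

// exampleRecord is a hypothetical row type; how its fields map to columns
// follows whatever convention the real test fixtures use.
type exampleRecord struct {
	Id      int
	Message string
}

func upsertExample(db *pgxpool.Pool, records []exampleRecord) error {
	// Bound the operation instead of passing an unbounded Background context.
	ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second)
	defer cancel()
	return database.UpsertWithTransaction(ctx, db, "example_table", records)
}
```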
@@ -207,7 +208,7 @@ func setMessageToExecutor(runs []Record, executor string) { } func selectRecords(db *pgxpool.Pool) ([]Record, error) { - rows, err := db.Query(context.Background(), fmt.Sprintf("SELECT id, message, value, serial FROM %s order by value", TABLE_NAME)) + rows, err := db.Query(armadacontext.Background(), fmt.Sprintf("SELECT id, message, value, serial FROM %s order by value", TABLE_NAME)) if err != nil { return nil, err } diff --git a/internal/common/etcdhealth/etcdhealth.go b/internal/common/etcdhealth/etcdhealth.go index 804a89542f4..49be27a22fe 100644 --- a/internal/common/etcdhealth/etcdhealth.go +++ b/internal/common/etcdhealth/etcdhealth.go @@ -1,7 +1,6 @@ package etcdhealth import ( - "context" "sync" "time" @@ -9,6 +8,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/healthmonitor" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/metrics" @@ -184,7 +184,7 @@ func (srv *EtcdReplicaHealthMonitor) sizeFraction() float64 { return srv.etcdSizeBytes / srv.etcdCapacityBytes } -func (srv *EtcdReplicaHealthMonitor) Run(ctx context.Context, log *logrus.Entry) error { +func (srv *EtcdReplicaHealthMonitor) Run(ctx *armadacontext.Context, log *logrus.Entry) error { log = log.WithField("service", "EtcdHealthMonitor") log.Info("starting etcd health monitor") defer log.Info("stopping etcd health monitor") @@ -264,7 +264,7 @@ func (srv *EtcdReplicaHealthMonitor) setCapacityBytesFromMetrics(metrics map[str // BlockUntilNextMetricsCollection blocks until the next metrics collection has completed, // or until ctx is cancelled, whichever occurs first. -func (srv *EtcdReplicaHealthMonitor) BlockUntilNextMetricsCollection(ctx context.Context) { +func (srv *EtcdReplicaHealthMonitor) BlockUntilNextMetricsCollection(ctx *armadacontext.Context) { c := make(chan struct{}) srv.mu.Lock() srv.watchers = append(srv.watchers, c) diff --git a/internal/common/etcdhealth/etcdhealth_test.go b/internal/common/etcdhealth/etcdhealth_test.go index 22435861a61..474d4df0e3a 100644 --- a/internal/common/etcdhealth/etcdhealth_test.go +++ b/internal/common/etcdhealth/etcdhealth_test.go @@ -1,14 +1,13 @@ package etcdhealth import ( - "context" "testing" "time" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/healthmonitor" "github.com/armadaproject/armada/internal/common/metrics" ) @@ -24,9 +23,9 @@ func TestEtcdReplicaHealthMonitor(t *testing.T) { assert.NoError(t, err) // Start the metrics collection service. - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) defer cancel() - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) g.Go(func() error { return hm.Run(ctx, logrus.NewEntry(logrus.New())) }) // Should still be unavailable due to missing metrics. 
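The etcdhealth test above captures the migration pattern in miniature: the `armadacontext` constructors (`Background`, `WithCancel`, `WithTimeout`, `WithDeadline`) replace their `context` counterparts, and `armadacontext.ErrGroup` replaces `errgroup.WithContext` so that spawned goroutines receive the wrapped context. A minimal sketch of that caller-side pattern follows; the `runner` interface is illustrative, and it assumes `ErrGroup` mirrors `errgroup.Group`'s `Go`/`Wait` semantics (only `Go` is visible in this diff).

```
package example

import (
	"github.com/sirupsen/logrus"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

// runner stands in for any service in this patch whose Run method now takes
// a *armadacontext.Context together with a logrus entry.
type runner interface {
	Run(ctx *armadacontext.Context, log *logrus.Entry) error
}

// runAll mirrors the migrated call sites: armadacontext constructors replace
// those from the context package, and armadacontext.ErrGroup replaces
// errgroup.WithContext so each goroutine gets the wrapped context.
func runAll(services []runner) error {
	ctx, cancel := armadacontext.WithCancel(armadacontext.Background())
	defer cancel()

	g, ctx := armadacontext.ErrGroup(ctx)
	for _, svc := range services {
		svc := svc // capture loop variable
		g.Go(func() error { return svc.Run(ctx, logrus.NewEntry(logrus.New())) })
	}
	return g.Wait()
}
```

Because the wrapped context is still passed to pgx, Pulsar, and client-go throughout this patch, it evidently satisfies the standard `context.Context` interface, so existing cancellation and deadline behaviour is preserved.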
diff --git a/internal/common/eventutil/eventutil.go b/internal/common/eventutil/eventutil.go index 05ee5d473c9..10d5baf4885 100644 --- a/internal/common/eventutil/eventutil.go +++ b/internal/common/eventutil/eventutil.go @@ -1,7 +1,6 @@ package eventutil import ( - "context" "fmt" "math" "time" @@ -14,6 +13,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -25,7 +25,7 @@ import ( // UnmarshalEventSequence returns an EventSequence object contained in a byte buffer // after validating that the resulting EventSequence is valid. -func UnmarshalEventSequence(ctx context.Context, payload []byte) (*armadaevents.EventSequence, error) { +func UnmarshalEventSequence(ctx *armadacontext.Context, payload []byte) (*armadaevents.EventSequence, error) { sequence := &armadaevents.EventSequence{} err := proto.Unmarshal(payload, sequence) if err != nil { diff --git a/internal/common/eventutil/sequence_from_message.go b/internal/common/eventutil/sequence_from_message.go deleted file mode 100644 index cc1749c392e..00000000000 --- a/internal/common/eventutil/sequence_from_message.go +++ /dev/null @@ -1,193 +0,0 @@ -package eventutil - -import ( - "context" - "time" - - "github.com/apache/pulsar-client-go/pulsar" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" - - "github.com/armadaproject/armada/internal/common/logging" - "github.com/armadaproject/armada/pkg/armadaevents" -) - -// PulsarToChannel is a service for receiving messages from Pulsar and forwarding those on C. -type SequenceFromMessage struct { - In chan pulsar.Message - Out chan *EventSequenceWithMessageIds -} - -// EventSequenceWithMessageIds bundles an event sequence with -// all the ids of all Pulsar messages that were consumed to produce it. -type EventSequenceWithMessageIds struct { - Sequence *armadaevents.EventSequence - MessageIds []pulsar.MessageID -} - -func NewSequenceFromMessage(in chan pulsar.Message) *SequenceFromMessage { - return &SequenceFromMessage{ - In: in, - Out: make(chan *EventSequenceWithMessageIds), - } -} - -func (srv *SequenceFromMessage) Run(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case msg := <-srv.In: - if msg == nil { - break - } - sequence, err := UnmarshalEventSequence(ctx, msg.Payload()) - if err != nil { - logging.WithStacktrace(log, err).WithField("messageid", msg.ID()).Error("failed to unmarshal event sequence") - break - } - - sequenceWithMessageIds := &EventSequenceWithMessageIds{ - Sequence: sequence, - MessageIds: []pulsar.MessageID{msg.ID()}, - } - select { - case <-ctx.Done(): - case srv.Out <- sequenceWithMessageIds: - } - } - } -} - -// SequenceCompacter reads sequences and produces compacted sequences. -// Compacted sequences are created by combining events in sequences with the -type SequenceCompacter struct { - In chan *EventSequenceWithMessageIds - Out chan *EventSequenceWithMessageIds - // Buffer messages for at most this long before forwarding on the outgoing channel. - Interval time.Duration - // Max number of events to buffer. - MaxEvents int - // Buffer of events to be compacted and sent. - buffer []*EventSequenceWithMessageIds - // Number of events collected so far. 
- numEvents int -} - -func NewSequenceCompacter(in chan *EventSequenceWithMessageIds) *SequenceCompacter { - return &SequenceCompacter{ - In: in, - Out: make(chan *EventSequenceWithMessageIds), - Interval: 5 * time.Second, - MaxEvents: 10000, - } -} - -func (srv *SequenceCompacter) Run(ctx context.Context) error { - ticker := time.NewTicker(srv.Interval) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - err := srv.compactAndSend(ctx) - if err != nil { - return err - } - case sequenceWithIds := <-srv.In: - if sequenceWithIds == nil || sequenceWithIds.Sequence == nil { - break - } - srv.buffer = append(srv.buffer, sequenceWithIds) - srv.numEvents += len(sequenceWithIds.Sequence.Events) - if srv.numEvents > srv.MaxEvents { - err := srv.compactAndSend(ctx) - if err != nil { - return err - } - } - } - } -} - -func (srv *SequenceCompacter) compactAndSend(ctx context.Context) error { - if len(srv.buffer) == 0 { - return nil - } - - // Compact the event sequences. - // Note that we can't be sure of the number of message ids. - messageIds := make([]pulsar.MessageID, 0, len(srv.buffer)) - sequences := make([]*armadaevents.EventSequence, len(srv.buffer)) - for i, sequenceWithIds := range srv.buffer { - messageIds = append(messageIds, sequenceWithIds.MessageIds...) - sequences[i] = sequenceWithIds.Sequence - } - sequences = CompactEventSequences(sequences) - - for i, sequence := range sequences { - sequenceWithIds := &EventSequenceWithMessageIds{ - Sequence: sequence, - } - - // Add all message ids to the last sequence to be produced. - // To avoid later ack'ing messages the data of which has not yet been processed. - if i == len(sequences)-1 { - sequenceWithIds.MessageIds = messageIds - } - - select { - case <-ctx.Done(): - return ctx.Err() - case srv.Out <- sequenceWithIds: - } - } - - // Empty the buffer. - srv.buffer = nil - srv.numEvents = 0 - - return nil -} - -// EventFilter calls filter once for each event, -// and events for which filter returns false are discarded. -type EventFilter struct { - In chan *EventSequenceWithMessageIds - Out chan *EventSequenceWithMessageIds - // Filter function. Discard on returning false. 
- filter func(*armadaevents.EventSequence_Event) bool -} - -func NewEventFilter(in chan *EventSequenceWithMessageIds, filter func(*armadaevents.EventSequence_Event) bool) *EventFilter { - return &EventFilter{ - In: in, - Out: make(chan *EventSequenceWithMessageIds), - filter: filter, - } -} - -func (srv *EventFilter) Run(ctx context.Context) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - case sequenceWithIds := <-srv.In: - if sequenceWithIds == nil { - break - } - events := make([]*armadaevents.EventSequence_Event, 0, len(sequenceWithIds.Sequence.Events)) - for _, event := range sequenceWithIds.Sequence.Events { - if srv.filter(event) { - events = append(events, event) - } - } - sequenceWithIds.Sequence.Events = events - - select { - case <-ctx.Done(): - case srv.Out <- sequenceWithIds: - } - } - } -} diff --git a/internal/common/eventutil/sequence_from_message_test.go b/internal/common/eventutil/sequence_from_message_test.go deleted file mode 100644 index a4a1812b207..00000000000 --- a/internal/common/eventutil/sequence_from_message_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package eventutil - -// import ( -// "context" -// "testing" -// "time" - -// "github.com/armadaproject/armada/internal/pulsarutils" -// "github.com/armadaproject/armada/pkg/armadaevents" -// "github.com/apache/pulsar-client-go/pulsar" -// ) - -// func TestSequenceCompacter(t *testing.T) { - -// } - -// func TestEventFilter(t *testing.T) { -// tests := map[string]struct { -// filter func(*armadaevents.EventSequence_Event) bool -// n int // Number of event expected to pass the filter -// }{ -// "filter all": { -// filter: func(a *armadaevents.EventSequence_Event) bool { -// return false -// }, -// n: 0, -// }, -// "filter none": { -// filter: func(a *armadaevents.EventSequence_Event) bool { -// return true -// }, -// n: 1, -// }, -// } -// for name, tc := range tests { -// t.Run(name, func(t *testing.T) { -// C := make(chan *EventSequenceWithMessageIds, 1) -// eventFilter := NewEventFilter(C, tc.filter) -// ctx, _ := context.WithTimeout(context.Background(), time.Second) -// sequence := &EventSequenceWithMessageIds{ -// Sequence: &armadaevents.EventSequence{ -// Events: []*armadaevents.EventSequence_Event{ -// {Event: nil}, -// {Event: &armadaevents.EventSequence_Event_SubmitJob{}}, -// }, -// }, -// MessageIds: []pulsar.MessageID{pulsarutils.New(0, i, 0, 0)}, -// } -// C <- sequence - -// }) -// } -// } - -// func generateEvents(ctx context.Context, out chan *EventSequenceWithMessageIds) error { -// var i int64 -// for { -// sequence := EventSequenceWithMessageIds{ -// Sequence: &armadaevents.EventSequence{ -// Events: []*armadaevents.EventSequence_Event{ -// {Event: nil}, -// {Event: &armadaevents.EventSequence_Event_SubmitJob{}}, -// }, -// }, -// MessageIds: []pulsar.MessageID{pulsarutils.New(0, i, 0, 0)}, -// } -// select { -// case <-ctx.Done(): -// return ctx.Err() -// case out <- &sequence: -// } -// } -// } diff --git a/internal/common/grpc/grpc.go b/internal/common/grpc/grpc.go index 5f73c3801c0..43707dffadf 100644 --- a/internal/common/grpc/grpc.go +++ b/internal/common/grpc/grpc.go @@ -1,7 +1,6 @@ package grpc import ( - "context" "crypto/tls" "fmt" "net" @@ -23,6 +22,7 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" 
"github.com/armadaproject/armada/internal/common/certs" @@ -91,7 +91,7 @@ func CreateGrpcServer( if tlsConfig.Enabled { cachedCertificateService := certs.NewCachedCertificateService(tlsConfig.CertPath, tlsConfig.KeyPath, time.Minute) go func() { - cachedCertificateService.Run(context.Background()) + cachedCertificateService.Run(armadacontext.Background()) }() tlsCreds := credentials.NewTLS(&tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -130,7 +130,7 @@ func Listen(port uint16, grpcServer *grpc.Server, wg *sync.WaitGroup) { // CreateShutdownHandler returns a function that shuts down the grpcServer when the context is closed. // The server is given gracePeriod to perform a graceful showdown and is then forcably stopped if necessary -func CreateShutdownHandler(ctx context.Context, gracePeriod time.Duration, grpcServer *grpc.Server) func() error { +func CreateShutdownHandler(ctx *armadacontext.Context, gracePeriod time.Duration, grpcServer *grpc.Server) func() error { return func() error { <-ctx.Done() go func() { diff --git a/internal/common/healthmonitor/healthmonitor.go b/internal/common/healthmonitor/healthmonitor.go index aa196aaffda..d5c6b151c1e 100644 --- a/internal/common/healthmonitor/healthmonitor.go +++ b/internal/common/healthmonitor/healthmonitor.go @@ -1,10 +1,10 @@ package healthmonitor import ( - "context" - "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const ( @@ -25,5 +25,5 @@ type HealthMonitor interface { // Run initialises and starts the health checker. // Run may be blocking and should be run within a separate goroutine. // Must be called before IsHealthy() or any prometheus.Collector interface methods. - Run(context.Context, *logrus.Entry) error + Run(*armadacontext.Context, *logrus.Entry) error } diff --git a/internal/common/healthmonitor/manualhealthmonitor.go b/internal/common/healthmonitor/manualhealthmonitor.go index 1bc8a6d5b62..7aa2f525068 100644 --- a/internal/common/healthmonitor/manualhealthmonitor.go +++ b/internal/common/healthmonitor/manualhealthmonitor.go @@ -1,11 +1,12 @@ package healthmonitor import ( - "context" "sync" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // ManualHealthMonitor is a manually controlled health monitor. @@ -46,7 +47,7 @@ func (srv *ManualHealthMonitor) IsHealthy() (bool, string, error) { } } -func (srv *ManualHealthMonitor) Run(ctx context.Context, log *logrus.Entry) error { +func (srv *ManualHealthMonitor) Run(_ *armadacontext.Context, _ *logrus.Entry) error { return nil } diff --git a/internal/common/healthmonitor/multihealthmonitor.go b/internal/common/healthmonitor/multihealthmonitor.go index 8d9790fd91e..a9f03643d10 100644 --- a/internal/common/healthmonitor/multihealthmonitor.go +++ b/internal/common/healthmonitor/multihealthmonitor.go @@ -1,7 +1,6 @@ package healthmonitor import ( - "context" "fmt" "sync" @@ -9,7 +8,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "golang.org/x/exp/maps" - "golang.org/x/sync/errgroup" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // MultiHealthMonitor wraps multiple HealthMonitors and itself implements the HealthMonitor interface. 
@@ -100,8 +100,8 @@ func (srv *MultiHealthMonitor) IsHealthy() (ok bool, reason string, err error) { } // Run initialises prometheus metrics and starts any child health checkers. -func (srv *MultiHealthMonitor) Run(ctx context.Context, log *logrus.Entry) error { - g, ctx := errgroup.WithContext(ctx) +func (srv *MultiHealthMonitor) Run(ctx *armadacontext.Context, log *logrus.Entry) error { + g, ctx := armadacontext.ErrGroup(ctx) for _, healthMonitor := range srv.healthMonitorsByName { healthMonitor := healthMonitor g.Go(func() error { return healthMonitor.Run(ctx, log) }) diff --git a/internal/common/ingest/batch.go b/internal/common/ingest/batch.go index 7f07c915855..f099f646fae 100644 --- a/internal/common/ingest/batch.go +++ b/internal/common/ingest/batch.go @@ -1,12 +1,13 @@ package ingest import ( - "context" "sync" "time" log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // Batcher batches up events from a channel. Batches are created whenever maxItems have been @@ -32,7 +33,7 @@ func NewBatcher[T any](input chan T, maxItems int, maxTimeout time.Duration, cal } } -func (b *Batcher[T]) Run(ctx context.Context) { +func (b *Batcher[T]) Run(ctx *armadacontext.Context) { for { b.buffer = []T{} expire := b.clock.After(b.maxTimeout) diff --git a/internal/common/ingest/batch_test.go b/internal/common/ingest/batch_test.go index 4c9fee650a1..a906dbc8258 100644 --- a/internal/common/ingest/batch_test.go +++ b/internal/common/ingest/batch_test.go @@ -5,11 +5,11 @@ import ( "testing" "time" - "golang.org/x/net/context" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const ( @@ -42,7 +42,7 @@ func (r *resultHolder) resultLength() int { } func TestBatch_MaxItems(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() @@ -67,7 +67,7 @@ func TestBatch_MaxItems(t *testing.T) { } func TestBatch_Time(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() @@ -89,7 +89,7 @@ func TestBatch_Time(t *testing.T) { } func TestBatch_Time_WithIntialQuiet(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() @@ -120,7 +120,7 @@ func TestBatch_Time_WithIntialQuiet(t *testing.T) { cancel() } -func waitForBufferLength(ctx context.Context, batcher *Batcher[int], numEvents int) error { +func waitForBufferLength(ctx *armadacontext.Context, batcher *Batcher[int], numEvents int) error { ticker := time.NewTicker(5 * time.Millisecond) for { select { @@ -134,7 +134,7 @@ func waitForBufferLength(ctx context.Context, batcher *Batcher[int], numEvents i } } -func waitForExpectedEvents(ctx context.Context, rh *resultHolder, numEvents int) { +func waitForExpectedEvents(ctx *armadacontext.Context, rh *resultHolder, numEvents int) { done := false ticker := 
time.NewTicker(5 * time.Millisecond) for !done { diff --git a/internal/common/ingest/ingestion_pipeline.go b/internal/common/ingest/ingestion_pipeline.go index 2b5e9a9e783..4236473d360 100644 --- a/internal/common/ingest/ingestion_pipeline.go +++ b/internal/common/ingest/ingestion_pipeline.go @@ -1,16 +1,17 @@ package ingest import ( + "context" "sync" "time" "github.com/apache/pulsar-client-go/pulsar" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/net/context" "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" commonmetrics "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -27,7 +28,7 @@ type HasPulsarMessageIds interface { // InstructionConverter should be implemented by structs that can convert a batch of event sequences into an object // suitable for passing to the sink type InstructionConverter[T HasPulsarMessageIds] interface { - Convert(ctx context.Context, msg *EventSequencesWithIds) T + Convert(ctx *armadacontext.Context, msg *EventSequencesWithIds) T } // Sink should be implemented by the struct responsible for putting the data in its final resting place, e.g. a @@ -35,7 +36,7 @@ type InstructionConverter[T HasPulsarMessageIds] interface { type Sink[T HasPulsarMessageIds] interface { // Store should persist the sink. The store is responsible for retrying failed attempts and should only return an error // When it is satisfied that operation cannot be retries. - Store(ctx context.Context, msg T) error + Store(ctx *armadacontext.Context, msg T) error } // EventSequencesWithIds consists of a batch of Event Sequences along with the corresponding Pulsar Message Ids @@ -122,7 +123,7 @@ func NewFilteredMsgIngestionPipeline[T HasPulsarMessageIds]( } // Run will run the ingestion pipeline until the supplied context is shut down -func (ingester *IngestionPipeline[T]) Run(ctx context.Context) error { +func (ingester *IngestionPipeline[T]) Run(ctx *armadacontext.Context) error { shutdownMetricServer := common.ServeMetrics(ingester.metricsConfig.Port) defer shutdownMetricServer() @@ -147,7 +148,7 @@ func (ingester *IngestionPipeline[T]) Run(ctx context.Context) error { // Set up a context that n seconds after ctx // This gives the rest of the pipeline a chance to flush pending messages - pipelineShutdownContext, cancel := context.WithCancel(context.Background()) + pipelineShutdownContext, cancel := armadacontext.WithCancel(armadacontext.Background()) go func() { for { select { @@ -206,7 +207,7 @@ func (ingester *IngestionPipeline[T]) Run(ctx context.Context) error { } else { for _, msgId := range msg.GetMessageIDs() { util.RetryUntilSuccess( - context.Background(), + armadacontext.Background(), func() error { return ingester.consumer.AckID(msgId) }, func(err error) { log.WithError(err).Warnf("Pulsar ack failed; backing off for %s", ingester.pulsarConfig.BackoffTime) @@ -265,7 +266,7 @@ func unmarshalEventSequences(batch []pulsar.Message, msgFilter func(msg pulsar.M } // Try and unmarshall the proto - es, err := eventutil.UnmarshalEventSequence(context.Background(), msg.Payload()) + es, err := eventutil.UnmarshalEventSequence(armadacontext.Background(), msg.Payload()) if err != nil { metrics.RecordPulsarMessageError(commonmetrics.PulsarMessageErrorDeserialization) log.WithError(err).Warnf("Could not 
unmarshal proto for msg %s", msg.ID()) diff --git a/internal/common/ingest/ingestion_pipeline_test.go b/internal/common/ingest/ingestion_pipeline_test.go index da0d653b39a..53dd6a7a39b 100644 --- a/internal/common/ingest/ingestion_pipeline_test.go +++ b/internal/common/ingest/ingestion_pipeline_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/pkg/armadaevents" @@ -191,7 +192,7 @@ func newSimpleConverter(t *testing.T) InstructionConverter[*simpleMessages] { return &simpleConverter{t} } -func (s *simpleConverter) Convert(_ context.Context, msg *EventSequencesWithIds) *simpleMessages { +func (s *simpleConverter) Convert(_ *armadacontext.Context, msg *EventSequencesWithIds) *simpleMessages { s.t.Helper() assert.Len(s.t, msg.EventSequences, len(msg.MessageIds)) var converted []*simpleMessage @@ -218,7 +219,7 @@ func newSimpleSink(t *testing.T) *simpleSink { } } -func (s *simpleSink) Store(_ context.Context, msg *simpleMessages) error { +func (s *simpleSink) Store(_ *armadacontext.Context, msg *simpleMessages) error { for _, simpleMessage := range msg.msgs { s.simpleMessages[simpleMessage.id] = simpleMessage } @@ -236,7 +237,7 @@ func (s *simpleSink) assertDidProcess(messages []pulsar.Message) { } func TestRun_HappyPath_SingleMessage(t *testing.T) { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) + ctx, cancel := armadacontext.WithDeadline(armadacontext.Background(), time.Now().Add(10*time.Second)) messages := []pulsar.Message{ pulsarutils.NewPulsarMessage(1, baseTime, marshal(t, succeeded)), } @@ -257,7 +258,7 @@ func TestRun_HappyPath_SingleMessage(t *testing.T) { } func TestRun_HappyPath_MultipleMessages(t *testing.T) { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) + ctx, cancel := armadacontext.WithDeadline(armadacontext.Background(), time.Now().Add(10*time.Second)) messages := []pulsar.Message{ pulsarutils.NewPulsarMessage(1, baseTime, marshal(t, succeeded)), pulsarutils.NewPulsarMessage(2, baseTime.Add(1*time.Second), marshal(t, pendingAndRunning)), diff --git a/internal/common/pgkeyvalue/pgkeyvalue.go b/internal/common/pgkeyvalue/pgkeyvalue.go index 8476146d727..d3f5f7d9401 100644 --- a/internal/common/pgkeyvalue/pgkeyvalue.go +++ b/internal/common/pgkeyvalue/pgkeyvalue.go @@ -1,7 +1,6 @@ package pgkeyvalue import ( - "context" "fmt" "time" @@ -10,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/logging" @@ -34,7 +34,7 @@ type PGKeyValueStore struct { clock clock.Clock } -func New(ctx context.Context, db *pgxpool.Pool, tableName string) (*PGKeyValueStore, error) { +func New(ctx *armadacontext.Context, db *pgxpool.Pool, tableName string) (*PGKeyValueStore, error) { if db == nil { return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ Name: "db", @@ -60,7 +60,7 @@ func New(ctx context.Context, db *pgxpool.Pool, tableName string) (*PGKeyValueSt }, nil } -func (c *PGKeyValueStore) Load(ctx context.Context, keys []string) 
(map[string][]byte, error) { +func (c *PGKeyValueStore) Load(ctx *armadacontext.Context, keys []string) (map[string][]byte, error) { rows, err := c.db.Query(ctx, fmt.Sprintf("SELECT KEY, VALUE FROM %s WHERE KEY = any($1)", c.tableName), keys) if err != nil { return nil, errors.WithStack(err) @@ -78,7 +78,7 @@ func (c *PGKeyValueStore) Load(ctx context.Context, keys []string) (map[string][ return kv, nil } -func (c *PGKeyValueStore) Store(ctx context.Context, kvs map[string][]byte) error { +func (c *PGKeyValueStore) Store(ctx *armadacontext.Context, kvs map[string][]byte) error { data := make([]KeyValue, 0, len(kvs)) for k, v := range kvs { data = append(data, KeyValue{ @@ -90,7 +90,7 @@ func (c *PGKeyValueStore) Store(ctx context.Context, kvs map[string][]byte) erro return database.UpsertWithTransaction(ctx, c.db, c.tableName, data) } -func createTableIfNotExists(ctx context.Context, db *pgxpool.Pool, tableName string) error { +func createTableIfNotExists(ctx *armadacontext.Context, db *pgxpool.Pool, tableName string) error { _, err := db.Exec(ctx, fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( key TEXT PRIMARY KEY, @@ -101,7 +101,7 @@ func createTableIfNotExists(ctx context.Context, db *pgxpool.Pool, tableName str } // Cleanup removes all key-value pairs older than lifespan. -func (c *PGKeyValueStore) cleanup(ctx context.Context, lifespan time.Duration) error { +func (c *PGKeyValueStore) cleanup(ctx *armadacontext.Context, lifespan time.Duration) error { sql := fmt.Sprintf("DELETE FROM %s WHERE (inserted <= $1);", c.tableName) _, err := c.db.Exec(ctx, sql, c.clock.Now().Add(-lifespan)) if err != nil { @@ -112,7 +112,7 @@ func (c *PGKeyValueStore) cleanup(ctx context.Context, lifespan time.Duration) e // PeriodicCleanup starts a goroutine that automatically runs the cleanup job // every interval until the provided context is cancelled. 
-func (c *PGKeyValueStore) PeriodicCleanup(ctx context.Context, interval time.Duration, lifespan time.Duration) error { +func (c *PGKeyValueStore) PeriodicCleanup(ctx *armadacontext.Context, interval time.Duration, lifespan time.Duration) error { log := logrus.StandardLogger().WithField("service", "PGKeyValueStoreCleanup") log.Info("service started") ticker := c.clock.NewTicker(interval) diff --git a/internal/common/pgkeyvalue/pgkeyvalue_test.go b/internal/common/pgkeyvalue/pgkeyvalue_test.go index c8a9beeb175..aa70c4ed7b9 100644 --- a/internal/common/pgkeyvalue/pgkeyvalue_test.go +++ b/internal/common/pgkeyvalue/pgkeyvalue_test.go @@ -1,7 +1,6 @@ package pgkeyvalue import ( - "context" "testing" "time" @@ -11,11 +10,12 @@ import ( "golang.org/x/exp/maps" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/lookout/testutil" ) func TestLoadStore(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { kvStore, err := New(ctx, db, "cachetable") @@ -47,7 +47,7 @@ func TestLoadStore(t *testing.T) { } func TestCleanup(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { baseTime := time.Now() diff --git a/internal/common/pulsarutils/async.go b/internal/common/pulsarutils/async.go index 8f71781d558..9040eed5fe9 100644 --- a/internal/common/pulsarutils/async.go +++ b/internal/common/pulsarutils/async.go @@ -7,11 +7,11 @@ import ( "sync" "time" - commonmetrics "github.com/armadaproject/armada/internal/common/ingest/metrics" - "github.com/apache/pulsar-client-go/pulsar" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" + commonmetrics "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/util" ) @@ -36,7 +36,7 @@ type ConsumerMessage struct { var msgLogger = logrus.NewEntry(logrus.StandardLogger()) func Receive( - ctx context.Context, + ctx *armadacontext.Context, consumer pulsar.Consumer, receiveTimeout time.Duration, backoffTime time.Duration, @@ -76,7 +76,7 @@ func Receive( return default: // Get a message from Pulsar, which consists of a sequence of events (i.e., state transitions). - ctxWithTimeout, cancel := context.WithTimeout(ctx, receiveTimeout) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, receiveTimeout) msg, err := consumer.Receive(ctxWithTimeout) if errors.Is(err, context.DeadlineExceeded) { msgLogger.Debugf("No message received") @@ -109,7 +109,7 @@ func Receive( // Ack will ack all pulsar messages coming in on the msgs channel. The incoming messages contain a consumer id which // corresponds to the index of the consumer that should be used to perform the ack. 
In theory, the acks could be done // in parallel, however its unlikely that they will be a performance bottleneck -func Ack(ctx context.Context, consumers []pulsar.Consumer, msgs chan []*ConsumerMessageId, backoffTime time.Duration, wg *sync.WaitGroup) { +func Ack(ctx *armadacontext.Context, consumers []pulsar.Consumer, msgs chan []*ConsumerMessageId, backoffTime time.Duration, wg *sync.WaitGroup) { for msg := range msgs { for _, id := range msg { if id.ConsumerId < 0 || id.ConsumerId >= len(consumers) { diff --git a/internal/common/pulsarutils/async_test.go b/internal/common/pulsarutils/async_test.go index d47151c660d..bb8739254df 100644 --- a/internal/common/pulsarutils/async_test.go +++ b/internal/common/pulsarutils/async_test.go @@ -1,16 +1,16 @@ package pulsarutils import ( - ctx "context" + "context" "sync" "testing" "time" - "github.com/armadaproject/armada/internal/common/ingest/metrics" - "github.com/apache/pulsar-client-go/pulsar" "github.com/stretchr/testify/assert" - "golang.org/x/net/context" + + "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/ingest/metrics" ) var m = metrics.NewMetrics("test_pulsarutils_") @@ -46,8 +46,8 @@ func TestReceive(t *testing.T) { consumer := &mockConsumer{ msgs: msgs, } - context, cancel := ctx.WithCancel(ctx.Background()) - outputChan := Receive(context, consumer, 10*time.Millisecond, 10*time.Millisecond, m) + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) + outputChan := Receive(ctx, consumer, 10*time.Millisecond, 10*time.Millisecond, m) var receivedMsgs []pulsar.Message wg := sync.WaitGroup{} @@ -71,7 +71,7 @@ func TestAcks(t *testing.T) { consumers := []pulsar.Consumer{&mockConsumer} wg := sync.WaitGroup{} wg.Add(1) - go Ack(ctx.Background(), consumers, input, 1*time.Second, &wg) + go Ack(armadacontext.Background(), consumers, input, 1*time.Second, &wg) input <- []*ConsumerMessageId{ {NewMessageId(1), 0, 0}, {NewMessageId(2), 0, 0}, } diff --git a/internal/common/pulsarutils/eventsequence.go b/internal/common/pulsarutils/eventsequence.go index 49325bd0b2b..3750a1b11e8 100644 --- a/internal/common/pulsarutils/eventsequence.go +++ b/internal/common/pulsarutils/eventsequence.go @@ -1,24 +1,23 @@ package pulsarutils import ( - "context" "sync/atomic" - "github.com/armadaproject/armada/internal/common/schedulers" - "github.com/apache/pulsar-client-go/pulsar" "github.com/gogo/protobuf/proto" "github.com/hashicorp/go-multierror" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/requestid" + "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" ) // CompactAndPublishSequences reduces the number of sequences to the smallest possible, // while respecting per-job set ordering and max Pulsar message size, and then publishes to Pulsar. -func CompactAndPublishSequences(ctx context.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes uint, scheduler schedulers.Scheduler) error { +func CompactAndPublishSequences(ctx *armadacontext.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes uint, scheduler schedulers.Scheduler) error { // Reduce the number of sequences to send to the minimum possible, // and then break up any sequences larger than maxMessageSizeInBytes. 
sequences = eventutil.CompactEventSequences(sequences) @@ -38,7 +37,7 @@ func CompactAndPublishSequences(ctx context.Context, sequences []*armadaevents.E // and // eventutil.LimitSequencesByteSize(sequences, int(srv.MaxAllowedMessageSize)) // before passing to this function. -func PublishSequences(ctx context.Context, producer pulsar.Producer, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { +func PublishSequences(ctx *armadacontext.Context, producer pulsar.Producer, sequences []*armadaevents.EventSequence, scheduler schedulers.Scheduler) error { // Incoming gRPC requests are annotated with a unique id. // Pass this id through the log by adding it to the Pulsar message properties. requestId := requestid.FromContextOrMissing(ctx) diff --git a/internal/common/pulsarutils/eventsequence_test.go b/internal/common/pulsarutils/eventsequence_test.go index 0613a9f3462..0832195beac 100644 --- a/internal/common/pulsarutils/eventsequence_test.go +++ b/internal/common/pulsarutils/eventsequence_test.go @@ -9,19 +9,20 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" ) func TestPublishSequences_SendAsyncErr(t *testing.T) { producer := &mockProducer{} - err := PublishSequences(context.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) + err := PublishSequences(armadacontext.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) assert.NoError(t, err) producer = &mockProducer{ sendAsyncErr: errors.New("sendAsyncErr"), } - err = PublishSequences(context.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) + err = PublishSequences(armadacontext.Background(), producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) assert.ErrorIs(t, err, producer.sendAsyncErr) } @@ -29,7 +30,7 @@ func TestPublishSequences_RespectTimeout(t *testing.T) { producer := &mockProducer{ sendAsyncDuration: 1 * time.Second, } - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Millisecond) defer cancel() err := PublishSequences(ctx, producer, []*armadaevents.EventSequence{{}}, schedulers.Pulsar) assert.ErrorIs(t, err, context.DeadlineExceeded) diff --git a/internal/common/startup.go b/internal/common/startup.go index e14fa5a21a7..eebc23cdd3d 100644 --- a/internal/common/startup.go +++ b/internal/common/startup.go @@ -1,7 +1,6 @@ package common import ( - "context" "fmt" "net/http" "os" @@ -18,6 +17,7 @@ import ( "github.com/spf13/viper" "github.com/weaveworks/promrus" + "github.com/armadaproject/armada/internal/common/armadacontext" commonconfig "github.com/armadaproject/armada/internal/common/config" "github.com/armadaproject/armada/internal/common/logging" ) @@ -159,7 +159,7 @@ func ServeHttp(port uint16, mux http.Handler) (shutdown func()) { // TODO There's no need for this function to panic, since the main goroutine will exit. // Instead, just log an error. 
return func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() log.Printf("Stopping http server listening on %d", port) e := srv.Shutdown(ctx) diff --git a/internal/common/util/context.go b/internal/common/util/context.go index c96b4f0adee..1f6fa6519f4 100644 --- a/internal/common/util/context.go +++ b/internal/common/util/context.go @@ -1,11 +1,12 @@ package util import ( - "context" "time" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func CloseToDeadline(ctx context.Context, tolerance time.Duration) bool { +func CloseToDeadline(ctx *armadacontext.Context, tolerance time.Duration) bool { deadline, exists := ctx.Deadline() return exists && deadline.Before(time.Now().Add(tolerance)) } diff --git a/internal/common/util/retry.go b/internal/common/util/retry.go index 9f178c037d8..c688614e63e 100644 --- a/internal/common/util/retry.go +++ b/internal/common/util/retry.go @@ -1,8 +1,10 @@ package util -import "golang.org/x/net/context" +import ( + "github.com/armadaproject/armada/internal/common/armadacontext" +) -func RetryUntilSuccess(ctx context.Context, performAction func() error, onError func(error)) { +func RetryUntilSuccess(ctx *armadacontext.Context, performAction func() error, onError func(error)) { for { select { case <-ctx.Done(): diff --git a/internal/common/util/retry_test.go b/internal/common/util/retry_test.go index 43180ac6f39..2ad6ea4b300 100644 --- a/internal/common/util/retry_test.go +++ b/internal/common/util/retry_test.go @@ -1,16 +1,17 @@ package util import ( - "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) func TestRetryDoesntSpin(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 1*time.Second) defer cancel() RetryUntilSuccess( @@ -30,7 +31,7 @@ func TestRetryDoesntSpin(t *testing.T) { } func TestRetryCancel(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 1*time.Second) defer cancel() RetryUntilSuccess( @@ -61,7 +62,7 @@ func TestSucceedsAfterFailures(t *testing.T) { errorCount := 0 - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 1*time.Second) defer cancel() RetryUntilSuccess( diff --git a/internal/eventingester/convert/conversions.go b/internal/eventingester/convert/conversions.go index cab978e5812..fbb66a0c481 100644 --- a/internal/eventingester/convert/conversions.go +++ b/internal/eventingester/convert/conversions.go @@ -1,12 +1,11 @@ package convert import ( - "context" - "github.com/gogo/protobuf/proto" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/ingest" @@ -30,7 +29,7 @@ func NewEventConverter(compressor compress.Compressor, maxMessageBatchSize uint, } } -func (ec *EventConverter) Convert(ctx context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.BatchUpdate { +func (ec *EventConverter) Convert(ctx *armadacontext.Context, sequencesWithIds 
*ingest.EventSequencesWithIds) *model.BatchUpdate { // Remove all groups as they are potentially quite large for _, es := range sequencesWithIds.EventSequences { es.Groups = nil diff --git a/internal/eventingester/convert/conversions_test.go b/internal/eventingester/convert/conversions_test.go index c716f84815a..24ff9013733 100644 --- a/internal/eventingester/convert/conversions_test.go +++ b/internal/eventingester/convert/conversions_test.go @@ -1,7 +1,6 @@ package convert import ( - "context" "math/rand" "testing" "time" @@ -11,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -55,7 +55,7 @@ var cancelled = &armadaevents.EventSequence_Event{ func TestSingle(t *testing.T) { msg := NewMsg(jobRunSucceeded) converter := simpleEventConverter() - batchUpdate := converter.Convert(context.Background(), msg) + batchUpdate := converter.Convert(armadacontext.Background(), msg) expectedSequence := armadaevents.EventSequence{ Events: []*armadaevents.EventSequence_Event{jobRunSucceeded}, } @@ -72,7 +72,7 @@ func TestSingle(t *testing.T) { func TestMultiple(t *testing.T) { msg := NewMsg(cancelled, jobRunSucceeded) converter := simpleEventConverter() - batchUpdate := converter.Convert(context.Background(), msg) + batchUpdate := converter.Convert(armadacontext.Background(), msg) expectedSequence := armadaevents.EventSequence{ Events: []*armadaevents.EventSequence_Event{cancelled, jobRunSucceeded}, } @@ -113,7 +113,7 @@ func TestCancelled(t *testing.T) { }, }) converter := simpleEventConverter() - batchUpdate := converter.Convert(context.Background(), msg) + batchUpdate := converter.Convert(armadacontext.Background(), msg) assert.Equal(t, 1, len(batchUpdate.Events)) event := batchUpdate.Events[0] es, err := extractEventSeq(event.Event) diff --git a/internal/eventingester/store/eventstore.go b/internal/eventingester/store/eventstore.go index 2f9dc7555a2..981e8460c16 100644 --- a/internal/eventingester/store/eventstore.go +++ b/internal/eventingester/store/eventstore.go @@ -1,7 +1,6 @@ package store import ( - "context" "regexp" "time" @@ -9,6 +8,7 @@ import ( "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/eventingester/configuration" "github.com/armadaproject/armada/internal/eventingester/model" @@ -39,7 +39,7 @@ func NewRedisEventStore(db redis.UniversalClient, eventRetention configuration.E } } -func (repo *RedisEventStore) Store(ctx context.Context, update *model.BatchUpdate) error { +func (repo *RedisEventStore) Store(ctx *armadacontext.Context, update *model.BatchUpdate) error { if len(update.Events) == 0 { return nil } diff --git a/internal/eventingester/store/eventstore_test.go b/internal/eventingester/store/eventstore_test.go index 3327a4fff95..1584b56ba15 100644 --- a/internal/eventingester/store/eventstore_test.go +++ b/internal/eventingester/store/eventstore_test.go @@ -1,13 +1,13 @@ package store import ( - "context" "testing" "time" "github.com/go-redis/redis" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/eventingester/configuration" 
"github.com/armadaproject/armada/internal/eventingester/model" ) @@ -29,7 +29,7 @@ func TestReportEvents(t *testing.T) { }, } - err := r.Store(context.Background(), update) + err := r.Store(armadacontext.Background(), update) assert.NoError(t, err) read1, err := ReadEvent(r.db, "testQueue", "testJobset") diff --git a/internal/executor/application.go b/internal/executor/application.go index 6a15c0f9414..3cb4db15af0 100644 --- a/internal/executor/application.go +++ b/internal/executor/application.go @@ -1,7 +1,6 @@ package executor import ( - "context" "fmt" "net/http" "os" @@ -14,10 +13,10 @@ import ( grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/cluster" "github.com/armadaproject/armada/internal/common/etcdhealth" "github.com/armadaproject/armada/internal/common/healthmonitor" @@ -41,7 +40,7 @@ import ( "github.com/armadaproject/armada/pkg/executorapi" ) -func StartUp(ctx context.Context, log *logrus.Entry, config configuration.ExecutorConfiguration) (func(), *sync.WaitGroup) { +func StartUp(ctx *armadacontext.Context, log *logrus.Entry, config configuration.ExecutorConfiguration) (func(), *sync.WaitGroup) { err := validateConfig(config) if err != nil { log.Errorf("Invalid config: %s", err) @@ -59,7 +58,7 @@ func StartUp(ctx context.Context, log *logrus.Entry, config configuration.Execut } // Create an errgroup to run services in. - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) // Setup etcd health monitoring. etcdClusterHealthMonitoringByName := make(map[string]healthmonitor.HealthMonitor, len(config.Kubernetes.Etcd.EtcdClustersHealthMonitoring)) diff --git a/internal/executor/context/cluster_context.go b/internal/executor/context/cluster_context.go index 79619ea06fd..555303fe9a3 100644 --- a/internal/executor/context/cluster_context.go +++ b/internal/executor/context/cluster_context.go @@ -1,7 +1,6 @@ package context import ( - "context" "encoding/json" "fmt" "time" @@ -26,6 +25,7 @@ import ( "k8s.io/kubelet/pkg/apis/stats/v1alpha1" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/cluster" util2 "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -50,7 +50,7 @@ type ClusterContext interface { GetActiveBatchPods() ([]*v1.Pod, error) GetNodes() ([]*v1.Node, error) GetNode(nodeName string) (*v1.Node, error) - GetNodeStatsSummary(context.Context, *v1.Node) (*v1alpha1.Summary, error) + GetNodeStatsSummary(*armadacontext.Context, *v1.Node) (*v1alpha1.Summary, error) GetPodEvents(pod *v1.Pod) ([]*v1.Event, error) GetServices(pod *v1.Pod) ([]*v1.Service, error) GetIngresses(pod *v1.Pod) ([]*networking.Ingress, error) @@ -223,7 +223,7 @@ func (c *KubernetesClusterContext) GetNode(nodeName string) (*v1.Node, error) { return c.nodeInformer.Lister().Get(nodeName) } -func (c *KubernetesClusterContext) GetNodeStatsSummary(ctx context.Context, node *v1.Node) (*v1alpha1.Summary, error) { +func (c *KubernetesClusterContext) GetNodeStatsSummary(ctx *armadacontext.Context, node *v1.Node) (*v1alpha1.Summary, error) { request := c.kubernetesClient. CoreV1(). RESTClient(). 
@@ -253,7 +253,7 @@ func (c *KubernetesClusterContext) SubmitPod(pod *v1.Pod, owner string, ownerGro return nil, err } - returnedPod, err := ownerClient.CoreV1().Pods(pod.Namespace).Create(context.Background(), pod, metav1.CreateOptions{}) + returnedPod, err := ownerClient.CoreV1().Pods(pod.Namespace).Create(armadacontext.Background(), pod, metav1.CreateOptions{}) if err != nil { c.submittedPods.Delete(util.ExtractPodKey(pod)) } @@ -261,11 +261,11 @@ func (c *KubernetesClusterContext) SubmitPod(pod *v1.Pod, owner string, ownerGro } func (c *KubernetesClusterContext) SubmitService(service *v1.Service) (*v1.Service, error) { - return c.kubernetesClient.CoreV1().Services(service.Namespace).Create(context.Background(), service, metav1.CreateOptions{}) + return c.kubernetesClient.CoreV1().Services(service.Namespace).Create(armadacontext.Background(), service, metav1.CreateOptions{}) } func (c *KubernetesClusterContext) SubmitIngress(ingress *networking.Ingress) (*networking.Ingress, error) { - return c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Create(context.Background(), ingress, metav1.CreateOptions{}) + return c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Create(armadacontext.Background(), ingress, metav1.CreateOptions{}) } func (c *KubernetesClusterContext) AddAnnotation(pod *v1.Pod, annotations map[string]string) error { @@ -280,7 +280,7 @@ func (c *KubernetesClusterContext) AddAnnotation(pod *v1.Pod, annotations map[st } _, err = c.kubernetesClient.CoreV1(). Pods(pod.Namespace). - Patch(context.Background(), pod.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) + Patch(armadacontext.Background(), pod.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) if err != nil { return err } @@ -299,7 +299,7 @@ func (c *KubernetesClusterContext) AddClusterEventAnnotation(event *v1.Event, an } _, err = c.kubernetesClient.CoreV1(). Events(event.Namespace). 
- Patch(context.Background(), event.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) + Patch(armadacontext.Background(), event.Name, types.StrategicMergePatchType, patchBytes, metav1.PatchOptions{}) if err != nil { return err } @@ -318,7 +318,7 @@ func (c *KubernetesClusterContext) DeletePodWithCondition(pod *v1.Pod, condition return err } // Get latest pod state - bypassing cache - timeout, cancel := context.WithTimeout(context.Background(), time.Second*10) + timeout, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Second*10) defer cancel() currentPod, err = c.kubernetesClient.CoreV1().Pods(currentPod.Namespace).Get(timeout, currentPod.Name, metav1.GetOptions{}) if err != nil { @@ -368,7 +368,7 @@ func (c *KubernetesClusterContext) DeletePods(pods []*v1.Pod) { func (c *KubernetesClusterContext) DeleteService(service *v1.Service) error { deleteOptions := createDeleteOptions() - err := c.kubernetesClient.CoreV1().Services(service.Namespace).Delete(context.Background(), service.Name, deleteOptions) + err := c.kubernetesClient.CoreV1().Services(service.Namespace).Delete(armadacontext.Background(), service.Name, deleteOptions) if err != nil && k8s_errors.IsNotFound(err) { return nil } @@ -377,7 +377,7 @@ func (c *KubernetesClusterContext) DeleteService(service *v1.Service) error { func (c *KubernetesClusterContext) DeleteIngress(ingress *networking.Ingress) error { deleteOptions := createDeleteOptions() - err := c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Delete(context.Background(), ingress.Name, deleteOptions) + err := c.kubernetesClient.NetworkingV1().Ingresses(ingress.Namespace).Delete(armadacontext.Background(), ingress.Name, deleteOptions) if err != nil && k8s_errors.IsNotFound(err) { return nil } @@ -386,7 +386,7 @@ func (c *KubernetesClusterContext) DeleteIngress(ingress *networking.Ingress) er func (c *KubernetesClusterContext) ProcessPodsToDelete() { pods := c.podsToDelete.GetAll() - util.ProcessItemsWithThreadPool(context.Background(), c.deleteThreadCount, pods, func(podToDelete *v1.Pod) { + util.ProcessItemsWithThreadPool(armadacontext.Background(), c.deleteThreadCount, pods, func(podToDelete *v1.Pod) { if podToDelete == nil { return } @@ -438,7 +438,7 @@ func (c *KubernetesClusterContext) doDelete(pod *v1.Pod, force bool) { } func (c *KubernetesClusterContext) deletePod(pod *v1.Pod, deleteOptions metav1.DeleteOptions) error { - return c.kubernetesClient.CoreV1().Pods(pod.Namespace).Delete(context.Background(), pod.Name, deleteOptions) + return c.kubernetesClient.CoreV1().Pods(pod.Namespace).Delete(armadacontext.Background(), pod.Name, deleteOptions) } func (c *KubernetesClusterContext) markForDeletion(pod *v1.Pod) (*v1.Pod, error) { diff --git a/internal/executor/context/cluster_context_test.go b/internal/executor/context/cluster_context_test.go index d1836e82168..b382cd0e690 100644 --- a/internal/executor/context/cluster_context_test.go +++ b/internal/executor/context/cluster_context_test.go @@ -1,7 +1,6 @@ package context import ( - ctx "context" "encoding/json" "errors" "testing" @@ -23,6 +22,7 @@ import ( clientTesting "k8s.io/client-go/testing" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" util2 "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" "github.com/armadaproject/armada/internal/executor/domain" @@ -699,7 +699,7 @@ func TestKubernetesClusterContext_GetNodes(t *testing.T) { }, } - _, err := 
client.CoreV1().Nodes().Create(ctx.Background(), node, metav1.CreateOptions{}) + _, err := client.CoreV1().Nodes().Create(armadacontext.Background(), node, metav1.CreateOptions{}) assert.Nil(t, err) nodeFound := waitForCondition(func() bool { diff --git a/internal/executor/context/fake/sync_cluster_context.go b/internal/executor/context/fake/sync_cluster_context.go index 7a8d26797d0..d4a178920d0 100644 --- a/internal/executor/context/fake/sync_cluster_context.go +++ b/internal/executor/context/fake/sync_cluster_context.go @@ -1,7 +1,6 @@ package fake import ( - "context" "errors" "fmt" @@ -11,6 +10,7 @@ import ( "k8s.io/client-go/tools/cache" "k8s.io/kubelet/pkg/apis/stats/v1alpha1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/domain" ) @@ -132,7 +132,7 @@ func (c *SyncFakeClusterContext) GetClusterPool() string { return "pool" } -func (c *SyncFakeClusterContext) GetNodeStatsSummary(ctx context.Context, node *v1.Node) (*v1alpha1.Summary, error) { +func (c *SyncFakeClusterContext) GetNodeStatsSummary(ctx *armadacontext.Context, node *v1.Node) (*v1alpha1.Summary, error) { return &v1alpha1.Summary{}, nil } diff --git a/internal/executor/fake/context/context.go b/internal/executor/fake/context/context.go index 0cee687458b..906c23fe85f 100644 --- a/internal/executor/fake/context/context.go +++ b/internal/executor/fake/context/context.go @@ -1,7 +1,6 @@ package context import ( - "context" "fmt" "math/rand" "regexp" @@ -23,6 +22,7 @@ import ( "k8s.io/client-go/tools/cache" "k8s.io/kubelet/pkg/apis/stats/v1alpha1" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -314,7 +314,7 @@ func (c *FakeClusterContext) GetClusterPool() string { return c.pool } -func (c *FakeClusterContext) GetNodeStatsSummary(ctx context.Context, node *v1.Node) (*v1alpha1.Summary, error) { +func (c *FakeClusterContext) GetNodeStatsSummary(ctx *armadacontext.Context, node *v1.Node) (*v1alpha1.Summary, error) { return &v1alpha1.Summary{}, nil } diff --git a/internal/executor/job/job_context.go b/internal/executor/job/job_context.go index 3cc8b36f2b3..bcc5526ce2e 100644 --- a/internal/executor/job/job_context.go +++ b/internal/executor/job/job_context.go @@ -1,7 +1,6 @@ package job import ( - "context" "fmt" "sync" "time" @@ -10,6 +9,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/client-go/tools/cache" + "github.com/armadaproject/armada/internal/common/armadacontext" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/podchecks" "github.com/armadaproject/armada/internal/executor/util" @@ -149,7 +149,7 @@ func (c *ClusterJobContext) AddAnnotation(jobs []*RunningJob, annotations map[st } } - util.ProcessItemsWithThreadPool(context.Background(), c.updateThreadCount, podsToAnnotate, + util.ProcessItemsWithThreadPool(armadacontext.Background(), c.updateThreadCount, podsToAnnotate, func(pod *v1.Pod) { err := c.clusterContext.AddAnnotation(pod, annotations) if err != nil { diff --git a/internal/executor/job/processors/preempt_runs.go b/internal/executor/job/processors/preempt_runs.go index 9e48adb71a6..c296f7b75f7 100644 --- a/internal/executor/job/processors/preempt_runs.go +++ 
b/internal/executor/job/processors/preempt_runs.go @@ -1,13 +1,13 @@ package processors import ( - "context" "fmt" "time" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" "github.com/armadaproject/armada/internal/executor/job" @@ -46,7 +46,7 @@ func (j *RunPreemptedProcessor) Run() { }) runPodInfos := createRunPodInfos(runsToCancel, managedPods) - util.ProcessItemsWithThreadPool(context.Background(), 20, runPodInfos, + util.ProcessItemsWithThreadPool(armadacontext.Background(), 20, runPodInfos, func(runInfo *runPodInfo) { pod := runInfo.Pod if pod == nil { diff --git a/internal/executor/job/processors/remove_runs.go b/internal/executor/job/processors/remove_runs.go index 83038d8c1e1..37942110605 100644 --- a/internal/executor/job/processors/remove_runs.go +++ b/internal/executor/job/processors/remove_runs.go @@ -1,12 +1,12 @@ package processors import ( - "context" "time" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" "github.com/armadaproject/armada/internal/executor/job" @@ -37,7 +37,7 @@ func (j *RemoveRunProcessor) Run() { }) runPodInfos := createRunPodInfos(runsToCancel, managedPods) - util.ProcessItemsWithThreadPool(context.Background(), 20, runPodInfos, + util.ProcessItemsWithThreadPool(armadacontext.Background(), 20, runPodInfos, func(runInfo *runPodInfo) { pod := runInfo.Pod if pod == nil { diff --git a/internal/executor/reporter/event_sender.go b/internal/executor/reporter/event_sender.go index 9dd42a03f9d..d9afe0fa48b 100644 --- a/internal/executor/reporter/event_sender.go +++ b/internal/executor/reporter/event_sender.go @@ -1,13 +1,12 @@ package reporter import ( - "context" - "github.com/gogo/protobuf/proto" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/armadaevents" @@ -63,7 +62,7 @@ func (eventSender *ExecutorApiEventSender) SendEvents(events []EventMessage) err } for _, eventList := range eventLists { - _, err = eventSender.eventClient.ReportEvents(context.Background(), eventList) + _, err = eventSender.eventClient.ReportEvents(armadacontext.Background(), eventList) if err != nil { return err } diff --git a/internal/executor/reporter/event_sender_test.go b/internal/executor/reporter/event_sender_test.go index 1c91d1cd6f0..08e60daa521 100644 --- a/internal/executor/reporter/event_sender_test.go +++ b/internal/executor/reporter/event_sender_test.go @@ -205,13 +205,13 @@ func newFakeExecutorApiClient() *fakeExecutorApiClient { } } -func (fakeClient *fakeExecutorApiClient) LeaseJobRuns(ctx context.Context, opts ...grpc.CallOption) (executorapi.ExecutorApi_LeaseJobRunsClient, error) { +func (fakeClient *fakeExecutorApiClient) LeaseJobRuns(_ context.Context, opts ...grpc.CallOption) (executorapi.ExecutorApi_LeaseJobRunsClient, error) { // Not implemented return nil, nil } // Reports job run events to the scheduler -func (fakeClient *fakeExecutorApiClient) ReportEvents(ctx context.Context, in 
*executorapi.EventList, opts ...grpc.CallOption) (*types.Empty, error) { +func (fakeClient *fakeExecutorApiClient) ReportEvents(_ context.Context, in *executorapi.EventList, opts ...grpc.CallOption) (*types.Empty, error) { fakeClient.reportedEvents = append(fakeClient.reportedEvents, in) return nil, nil } diff --git a/internal/executor/service/job_lease.go b/internal/executor/service/job_lease.go index d8165c32a1b..1b18fc0c9d2 100644 --- a/internal/executor/service/job_lease.go +++ b/internal/executor/service/job_lease.go @@ -10,7 +10,6 @@ import ( grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/encoding/gzip" v1 "k8s.io/api/core/v1" @@ -18,6 +17,7 @@ import ( "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common" + "github.com/armadaproject/armada/internal/common/armadacontext" armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaresource "github.com/armadaproject/armada/internal/common/resource" commonUtil "github.com/armadaproject/armada/internal/common/util" @@ -111,10 +111,10 @@ func (jobLeaseService *JobLeaseService) requestJobLeases(leaseRequest *api.Strea // Setup a bidirectional gRPC stream. // The server sends jobs over this stream. // The executor sends back acks to indicate which jobs were successfully received. - ctx := context.Background() + ctx := armadacontext.Background() var cancel context.CancelFunc if jobLeaseService.jobLeaseRequestTimeout != 0 { - ctx, cancel = context.WithTimeout(ctx, jobLeaseService.jobLeaseRequestTimeout) + ctx, cancel = armadacontext.WithTimeout(ctx, jobLeaseService.jobLeaseRequestTimeout) defer cancel() } stream, err := jobLeaseService.queueClient.StreamingLeaseJobs(ctx, grpc_retry.Disable(), grpc.UseCompressor(gzip.Name)) @@ -137,7 +137,7 @@ func (jobLeaseService *JobLeaseService) requestJobLeases(leaseRequest *api.Strea var numJobs uint32 jobs := make([]*api.Job, 0) ch := make(chan *api.StreamingJobLease, 10) - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) g.Go(func() error { // Close channel to ensure sending goroutine exits. 
defer close(ch) diff --git a/internal/executor/service/job_manager.go b/internal/executor/service/job_manager.go index 4b8b1cfe016..496440d0538 100644 --- a/internal/executor/service/job_manager.go +++ b/internal/executor/service/job_manager.go @@ -1,13 +1,13 @@ package service import ( - "context" "fmt" "time" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" context2 "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" "github.com/armadaproject/armada/internal/executor/job" @@ -75,7 +75,7 @@ func (m *JobManager) ManageJobLeases() { } } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Minute*2) defer cancel() m.handlePodIssues(ctx, jobs) } @@ -108,7 +108,7 @@ func (m *JobManager) reportTerminated(pods []*v1.Pod) { } } -func (m *JobManager) handlePodIssues(ctx context.Context, allRunningJobs []*job.RunningJob) { +func (m *JobManager) handlePodIssues(ctx *armadacontext.Context, allRunningJobs []*job.RunningJob) { util.ProcessItemsWithThreadPool(ctx, 20, allRunningJobs, m.handlePodIssue) } diff --git a/internal/executor/service/job_requester.go b/internal/executor/service/job_requester.go index 217f279639e..53cf83c49a6 100644 --- a/internal/executor/service/job_requester.go +++ b/internal/executor/service/job_requester.go @@ -1,12 +1,12 @@ package service import ( - "context" "time" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/slices" util2 "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -56,7 +56,7 @@ func (r *JobRequester) RequestJobsRuns() { log.Errorf("Failed to create lease request because %s", err) return } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() leaseResponse, err := r.leaseRequester.LeaseJobRuns(ctx, leaseRequest) if err != nil { diff --git a/internal/executor/service/job_requester_test.go b/internal/executor/service/job_requester_test.go index 532e7e4fb0e..f7e3fcbc5b7 100644 --- a/internal/executor/service/job_requester_test.go +++ b/internal/executor/service/job_requester_test.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "testing" @@ -11,6 +10,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/executor/configuration" @@ -275,7 +275,7 @@ type StubLeaseRequester struct { LeaseJobRunLeaseResponse *LeaseResponse } -func (s *StubLeaseRequester) LeaseJobRuns(ctx context.Context, request *LeaseRequest) (*LeaseResponse, error) { +func (s *StubLeaseRequester) LeaseJobRuns(_ *armadacontext.Context, request *LeaseRequest) (*LeaseResponse, error) { s.ReceivedLeaseRequests = append(s.ReceivedLeaseRequests, request) return s.LeaseJobRunLeaseResponse, s.LeaseJobRunError } diff --git a/internal/executor/service/lease_requester.go b/internal/executor/service/lease_requester.go index 36a29f6e4f2..dc4976d84b5 100644 --- 
a/internal/executor/service/lease_requester.go +++ b/internal/executor/service/lease_requester.go @@ -10,6 +10,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/encoding/gzip" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaresource "github.com/armadaproject/armada/internal/common/resource" clusterContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/pkg/api" @@ -31,7 +32,7 @@ type LeaseResponse struct { } type LeaseRequester interface { - LeaseJobRuns(ctx context.Context, request *LeaseRequest) (*LeaseResponse, error) + LeaseJobRuns(ctx *armadacontext.Context, request *LeaseRequest) (*LeaseResponse, error) } type JobLeaseRequester struct { @@ -52,7 +53,7 @@ func NewJobLeaseRequester( } } -func (requester *JobLeaseRequester) LeaseJobRuns(ctx context.Context, request *LeaseRequest) (*LeaseResponse, error) { +func (requester *JobLeaseRequester) LeaseJobRuns(ctx *armadacontext.Context, request *LeaseRequest) (*LeaseResponse, error) { stream, err := requester.executorApiClient.LeaseJobRuns(ctx, grpcretry.Disable(), grpc.UseCompressor(gzip.Name)) if err != nil { return nil, err diff --git a/internal/executor/service/lease_requester_test.go b/internal/executor/service/lease_requester_test.go index 3f09cf450a7..f6314876c9f 100644 --- a/internal/executor/service/lease_requester_test.go +++ b/internal/executor/service/lease_requester_test.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "io" "testing" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/mocks" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/executor/context/fake" @@ -39,7 +39,7 @@ var ( ) func TestLeaseJobRuns(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() tests := map[string]struct { leaseMessages []*executorapi.JobRunLease @@ -87,7 +87,7 @@ func TestLeaseJobRuns(t *testing.T) { } func TestLeaseJobRuns_Send(t *testing.T) { - shortCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + shortCtx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() leaseRequest := &LeaseRequest{ @@ -126,7 +126,7 @@ func TestLeaseJobRuns_Send(t *testing.T) { func TestLeaseJobRuns_HandlesNoEndMarkerMessage(t *testing.T) { leaseMessages := []*executorapi.JobRunLease{lease1, lease2} - shortCtx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + shortCtx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 200*time.Millisecond) defer cancel() jobRequester, mockExecutorApiClient, mockStream := setup(t) @@ -146,7 +146,7 @@ func TestLeaseJobRuns_HandlesNoEndMarkerMessage(t *testing.T) { } func TestLeaseJobRuns_Error(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second) defer cancel() tests := map[string]struct { streamError bool diff --git a/internal/executor/service/pod_issue_handler.go b/internal/executor/service/pod_issue_handler.go index b98980df01b..57b323e7146 100644 --- a/internal/executor/service/pod_issue_handler.go +++ 
b/internal/executor/service/pod_issue_handler.go @@ -1,7 +1,6 @@ package service import ( - "context" "fmt" "sync" "time" @@ -11,6 +10,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "k8s.io/client-go/tools/cache" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/executor/configuration" executorContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/job" @@ -159,7 +159,7 @@ func (p *IssueHandler) HandlePodIssues() { }) p.detectPodIssues(managedPods) p.detectReconciliationIssues(managedPods) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*2) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Minute*2) defer cancel() p.handleKnownIssues(ctx, managedPods) } @@ -225,7 +225,7 @@ func (p *IssueHandler) detectPodIssues(allManagedPods []*v1.Pod) { } } -func (p *IssueHandler) handleKnownIssues(ctx context.Context, allManagedPods []*v1.Pod) { +func (p *IssueHandler) handleKnownIssues(ctx *armadacontext.Context, allManagedPods []*v1.Pod) { // Make issues from pods + issues issues := createIssues(allManagedPods, p.knownPodIssues) util.ProcessItemsWithThreadPool(ctx, 20, issues, p.handleRunIssue) diff --git a/internal/executor/util/process.go b/internal/executor/util/process.go index cc4da52d9a2..a38c316b5fa 100644 --- a/internal/executor/util/process.go +++ b/internal/executor/util/process.go @@ -1,13 +1,13 @@ package util import ( - "context" "sync" + "github.com/armadaproject/armada/internal/common/armadacontext" commonUtil "github.com/armadaproject/armada/internal/common/util" ) -func ProcessItemsWithThreadPool[K any](ctx context.Context, maxThreadCount int, itemsToProcess []K, processFunc func(K)) { +func ProcessItemsWithThreadPool[K any](ctx *armadacontext.Context, maxThreadCount int, itemsToProcess []K, processFunc func(K)) { wg := &sync.WaitGroup{} processChannel := make(chan K) @@ -24,7 +24,7 @@ func ProcessItemsWithThreadPool[K any](ctx context.Context, maxThreadCount int, wg.Wait() } -func poolWorker[K any](ctx context.Context, wg *sync.WaitGroup, podsToProcess chan K, processFunc func(K)) { +func poolWorker[K any](ctx *armadacontext.Context, wg *sync.WaitGroup, podsToProcess chan K, processFunc func(K)) { defer wg.Done() for pod := range podsToProcess { diff --git a/internal/executor/util/process_test.go b/internal/executor/util/process_test.go index cfdb237dea9..f6995106c70 100644 --- a/internal/executor/util/process_test.go +++ b/internal/executor/util/process_test.go @@ -1,12 +1,13 @@ package util import ( - "context" "sync" "testing" "time" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) func TestProcessItemsWithThreadPool(t *testing.T) { @@ -14,7 +15,7 @@ func TestProcessItemsWithThreadPool(t *testing.T) { output := []string{} outputMutex := &sync.Mutex{} - ProcessItemsWithThreadPool(context.Background(), 2, input, func(item string) { + ProcessItemsWithThreadPool(armadacontext.Background(), 2, input, func(item string) { outputMutex.Lock() defer outputMutex.Unlock() output = append(output, item) @@ -28,7 +29,7 @@ func TestProcessItemsWithThreadPool_HandlesContextCancellation(t *testing.T) { output := []string{} outputMutex := &sync.Mutex{} - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Millisecond*100) defer cancel() ProcessItemsWithThreadPool(ctx, 2, 
input, func(item string) { diff --git a/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go b/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go index 258d3740942..d7b32d01c7e 100644 --- a/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go +++ b/internal/executor/utilisation/pod_utilisation_kubelet_metrics.go @@ -1,7 +1,6 @@ package utilisation import ( - "context" "sync" "time" @@ -11,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" clusterContext "github.com/armadaproject/armada/internal/executor/context" "github.com/armadaproject/armada/internal/executor/domain" ) @@ -28,7 +28,7 @@ func (m *podUtilisationKubeletMetrics) fetch(nodes []*v1.Node, podNameToUtilisat wg.Add(1) go func(node *v1.Node) { defer wg.Done() - ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*15) + ctx, cancelFunc := armadacontext.WithTimeout(armadacontext.Background(), time.Second*15) defer cancelFunc() summary, err := clusterContext.GetNodeStatsSummary(ctx, node) if err != nil { diff --git a/internal/lookout/repository/job_pruner.go b/internal/lookout/repository/job_pruner.go index 6c9b92e60df..a77f1657007 100644 --- a/internal/lookout/repository/job_pruner.go +++ b/internal/lookout/repository/job_pruner.go @@ -1,12 +1,13 @@ package repository import ( - "context" "database/sql" "fmt" "time" log "github.com/sirupsen/logrus" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const postgresFormat = "2006-01-02 15:04:05.000000" @@ -22,7 +23,7 @@ const postgresFormat = "2006-01-02 15:04:05.000000" // For performance reasons we don't use a transaction here and so an error may indicate that // Some jobs were deleted. 
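DeleteOldJobs below is one of many call sites in this patch that swap context.Background and context.WithTimeout for their armadacontext equivalents. As a rough orientation, here is a minimal sketch of the shape such a wrapper might take, inferred only from the Background, New, WithTimeout and ErrGroup call sites visible in this diff; the actual internal/common/armadacontext package may well differ:

```
package armadacontext

import (
	"context"
	"time"

	"github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"
)

// Context bundles a context.Context with a logger so call sites can pass a
// single value where they previously passed context.Context alone.
type Context struct {
	context.Context
	Log *logrus.Entry
}

// New pairs an existing context.Context with a logger.
func New(ctx context.Context, log *logrus.Entry) *Context {
	return &Context{Context: ctx, Log: log}
}

// Background mirrors context.Background, attaching a default logger.
func Background() *Context {
	return New(context.Background(), logrus.NewEntry(logrus.New()))
}

// WithTimeout mirrors context.WithTimeout, preserving the parent's logger.
func WithTimeout(parent *Context, d time.Duration) (*Context, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(parent.Context, d)
	return New(ctx, parent.Log), cancel
}

// ErrGroup mirrors errgroup.WithContext; the derived context keeps the logger.
func ErrGroup(parent *Context) (*errgroup.Group, *Context) {
	g, ctx := errgroup.WithContext(parent.Context)
	return g, New(ctx, parent.Log)
}
```

Because Context embeds context.Context it still satisfies the standard interface, which is why the converted call sites in this patch can keep passing it straight through to pgx, client-go and gRPC.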
func DeleteOldJobs(db *sql.DB, batchSizeLimit int, cutoff time.Time) error { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 120*time.Second) defer cancel() // This would be much better done as a proper statement with parameters, but postgres doesn't support @@ -30,7 +31,7 @@ func DeleteOldJobs(db *sql.DB, batchSizeLimit int, cutoff time.Time) error { queryText := fmt.Sprintf(` CREATE TEMP TABLE rows_to_delete AS (SELECT job_id FROM job WHERE submitted < '%v' OR submitted IS NULL); CREATE TEMP TABLE batch (job_id varchar(32)); - + DO $do$ DECLARE @@ -52,7 +53,7 @@ func DeleteOldJobs(db *sql.DB, batchSizeLimit int, cutoff time.Time) error { END LOOP; END; $do$; - + DROP TABLE rows_to_delete; DROP TABLE batch; `, cutoff.Format(postgresFormat), batchSizeLimit) diff --git a/internal/lookout/repository/job_sets.go b/internal/lookout/repository/job_sets.go index 70bd187f866..28b60a48179 100644 --- a/internal/lookout/repository/job_sets.go +++ b/internal/lookout/repository/job_sets.go @@ -1,7 +1,6 @@ package repository import ( - "context" "database/sql" "time" @@ -9,6 +8,7 @@ import ( "github.com/doug-martin/goqu/v9/exp" "github.com/gogo/protobuf/types" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -38,7 +38,7 @@ type jobSetCountsRow struct { QueuedStatsQ3 sql.NullTime `db:"queued_q3"` } -func (r *SQLJobRepository) GetJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) { +func (r *SQLJobRepository) GetJobSetInfos(ctx *armadacontext.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) { rows, err := r.queryJobSetInfos(ctx, opts) if err != nil { return nil, err @@ -47,7 +47,7 @@ func (r *SQLJobRepository) GetJobSetInfos(ctx context.Context, opts *lookout.Get return r.rowsToJobSets(rows, opts.Queue), nil } -func (r *SQLJobRepository) queryJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*jobSetCountsRow, error) { +func (r *SQLJobRepository) queryJobSetInfos(ctx *armadacontext.Context, opts *lookout.GetJobSetsRequest) ([]*jobSetCountsRow, error) { ds := r.createJobSetsDataset(opts) jobsInQueueRows := make([]*jobSetCountsRow, 0) diff --git a/internal/lookout/repository/jobs.go b/internal/lookout/repository/jobs.go index dc03c6d43c4..3d8cb0994c1 100644 --- a/internal/lookout/repository/jobs.go +++ b/internal/lookout/repository/jobs.go @@ -1,7 +1,6 @@ package repository import ( - "context" "encoding/json" "errors" "fmt" @@ -13,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/duration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" @@ -20,7 +20,7 @@ import ( "github.com/armadaproject/armada/pkg/api/lookout" ) -func (r *SQLJobRepository) GetJobs(ctx context.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) { +func (r *SQLJobRepository) GetJobs(ctx *armadacontext.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) { if valid, jobState := validateJobStates(opts.JobStates); !valid { return nil, fmt.Errorf("unknown job state: %q", jobState) } @@ -57,7 +57,7 @@ func isJobState(val string) bool { return false } -func (r *SQLJobRepository) 
queryJobs(ctx context.Context, opts *lookout.GetJobsRequest) ([]*JobRow, error) { +func (r *SQLJobRepository) queryJobs(ctx *armadacontext.Context, opts *lookout.GetJobsRequest) ([]*JobRow, error) { ds := r.createJobsDataset(opts) jobsInQueueRows := make([]*JobRow, 0) diff --git a/internal/lookout/repository/queues.go b/internal/lookout/repository/queues.go index 32b40aeb3b7..0ad1909e849 100644 --- a/internal/lookout/repository/queues.go +++ b/internal/lookout/repository/queues.go @@ -1,7 +1,6 @@ package repository import ( - "context" "database/sql" "sort" "time" @@ -10,6 +9,7 @@ import ( "github.com/gogo/protobuf/types" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -25,7 +25,7 @@ type rowsSql struct { LongestRunning string } -func (r *SQLJobRepository) GetQueueInfos(ctx context.Context) ([]*lookout.QueueInfo, error) { +func (r *SQLJobRepository) GetQueueInfos(ctx *armadacontext.Context) ([]*lookout.QueueInfo, error) { queries, err := r.getQueuesSql() if err != nil { return nil, err diff --git a/internal/lookout/repository/sql_repository.go b/internal/lookout/repository/sql_repository.go index 42d72473c5e..af59e92c6ed 100644 --- a/internal/lookout/repository/sql_repository.go +++ b/internal/lookout/repository/sql_repository.go @@ -1,12 +1,12 @@ package repository import ( - "context" "database/sql" "github.com/doug-martin/goqu/v9" _ "github.com/doug-martin/goqu/v9/dialect/postgres" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -33,9 +33,9 @@ const ( ) type JobRepository interface { - GetQueueInfos(ctx context.Context) ([]*lookout.QueueInfo, error) - GetJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) - GetJobs(ctx context.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) + GetQueueInfos(ctx *armadacontext.Context) ([]*lookout.QueueInfo, error) + GetJobSetInfos(ctx *armadacontext.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) + GetJobs(ctx *armadacontext.Context, opts *lookout.GetJobsRequest) ([]*lookout.JobInfo, error) } type SQLJobRepository struct { diff --git a/internal/lookout/repository/utils_test.go b/internal/lookout/repository/utils_test.go index 1c073851d0f..54fb40bcc6a 100644 --- a/internal/lookout/repository/utils_test.go +++ b/internal/lookout/repository/utils_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/api/lookout" @@ -29,7 +29,7 @@ var ( node = "node" someTimeUnix = int64(1612546858) someTime = time.Unix(someTimeUnix, 0) - ctx = context.Background() + ctx = armadacontext.Background() ) func AssertJobsAreEquivalent(t *testing.T, expected *api.Job, actual *api.Job) { diff --git a/internal/lookout/server/lookout.go b/internal/lookout/server/lookout.go index df95e7bc2de..cf48fe278aa 100644 --- a/internal/lookout/server/lookout.go +++ b/internal/lookout/server/lookout.go @@ -4,9 +4,11 @@ import ( "context" "github.com/gogo/protobuf/types" + "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + 
"github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/lookout/repository" "github.com/armadaproject/armada/pkg/api/lookout" ) @@ -20,7 +22,7 @@ func NewLookoutServer(jobRepository repository.JobRepository) *LookoutServer { } func (s *LookoutServer) Overview(ctx context.Context, _ *types.Empty) (*lookout.SystemOverview, error) { - queues, err := s.jobRepository.GetQueueInfos(ctx) + queues, err := s.jobRepository.GetQueueInfos(armadacontext.New(ctx, logrus.NewEntry(logrus.New()))) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query queue stats: %s", err) } @@ -28,7 +30,7 @@ func (s *LookoutServer) Overview(ctx context.Context, _ *types.Empty) (*lookout. } func (s *LookoutServer) GetJobSets(ctx context.Context, opts *lookout.GetJobSetsRequest) (*lookout.GetJobSetsResponse, error) { - jobSets, err := s.jobRepository.GetJobSetInfos(ctx, opts) + jobSets, err := s.jobRepository.GetJobSetInfos(armadacontext.New(ctx, logrus.NewEntry(logrus.New())), opts) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query queue stats: %s", err) } @@ -36,7 +38,7 @@ func (s *LookoutServer) GetJobSets(ctx context.Context, opts *lookout.GetJobSets } func (s *LookoutServer) GetJobs(ctx context.Context, opts *lookout.GetJobsRequest) (*lookout.GetJobsResponse, error) { - jobInfos, err := s.jobRepository.GetJobs(ctx, opts) + jobInfos, err := s.jobRepository.GetJobs(armadacontext.New(ctx, logrus.NewEntry(logrus.New())), opts) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query jobs in queue: %s", err) } diff --git a/internal/lookout/testutil/db_testutil.go b/internal/lookout/testutil/db_testutil.go index 5ce57e8effa..eaba3992c15 100644 --- a/internal/lookout/testutil/db_testutil.go +++ b/internal/lookout/testutil/db_testutil.go @@ -1,7 +1,6 @@ package testutil import ( - "context" "database/sql" "fmt" @@ -9,6 +8,7 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookout/repository/schema" ) @@ -61,7 +61,7 @@ func WithDatabase(action func(db *sql.DB) error) error { } func WithDatabasePgx(action func(db *pgxpool.Pool) error) error { - ctx := context.Background() + ctx := armadacontext.Background() // Connect and create a dedicated database for the test // For now use database/sql for this diff --git a/internal/lookoutingester/instructions/instructions.go b/internal/lookoutingester/instructions/instructions.go index f49ac049975..2e6b314fc66 100644 --- a/internal/lookoutingester/instructions/instructions.go +++ b/internal/lookoutingester/instructions/instructions.go @@ -1,23 +1,21 @@ package instructions import ( - "context" "sort" "strings" "time" - "github.com/armadaproject/armada/internal/common/ingest/metrics" - "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "k8s.io/utils/pointer" - "github.com/armadaproject/armada/internal/common/ingest" - + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" + "github.com/armadaproject/armada/internal/common/ingest" + "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/util" 
"github.com/armadaproject/armada/internal/lookout/repository" "github.com/armadaproject/armada/internal/lookoutingester/model" @@ -42,7 +40,7 @@ func NewInstructionConverter(metrics *metrics.Metrics, userAnnotationPrefix stri } } -func (c *InstructionConverter) Convert(ctx context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { +func (c *InstructionConverter) Convert(ctx *armadacontext.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { updateInstructions := &model.InstructionSet{ MessageIds: sequencesWithIds.MessageIds, } diff --git a/internal/lookoutingester/instructions/instructions_test.go b/internal/lookoutingester/instructions/instructions_test.go index 5510a30a695..3f4f3043101 100644 --- a/internal/lookoutingester/instructions/instructions_test.go +++ b/internal/lookoutingester/instructions/instructions_test.go @@ -1,7 +1,6 @@ package instructions import ( - "context" "testing" "time" @@ -13,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/ingest" @@ -339,7 +339,7 @@ var expectedJobRunContainer = model.CreateJobRunContainerInstruction{ func TestSubmit(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(submit) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, MessageIds: msg.MessageIds, @@ -351,7 +351,7 @@ func TestSubmit(t *testing.T) { func TestDuplicate(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(testfixtures.SubmitDuplicate) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ MessageIds: msg.MessageIds, } @@ -364,7 +364,7 @@ func TestDuplicate(t *testing.T) { func TestHappyPathSingleUpdate(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(submit, assigned, running, jobRunSucceeded, jobSucceeded) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, JobsToUpdate: []*model.UpdateJobInstruction{&expectedLeased, &expectedRunning, &expectedJobSucceeded}, @@ -384,7 +384,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { svc := SimpleInstructionConverter() // Submit msg1 := NewMsg(submit) - instructions := svc.Convert(context.Background(), msg1) + instructions := svc.Convert(armadacontext.Background(), msg1) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, MessageIds: msg1.MessageIds, @@ -393,7 +393,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Leased msg2 := NewMsg(assigned) - instructions = svc.Convert(context.Background(), msg2) + instructions = svc.Convert(armadacontext.Background(), msg2) expected = &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedLeased}, JobRunsToCreate: []*model.CreateJobRunInstruction{&expectedLeasedRun}, @@ -403,7 +403,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Running msg3 := NewMsg(running) - instructions = svc.Convert(context.Background(), msg3) + instructions = 
svc.Convert(armadacontext.Background(), msg3) expected = &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedRunning}, JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedRunningRun}, @@ -413,7 +413,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Run Succeeded msg4 := NewMsg(jobRunSucceeded) - instructions = svc.Convert(context.Background(), msg4) + instructions = svc.Convert(armadacontext.Background(), msg4) expected = &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedJobRunSucceeded}, MessageIds: msg4.MessageIds, @@ -422,7 +422,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { // Job Succeeded msg5 := NewMsg(jobSucceeded) - instructions = svc.Convert(context.Background(), msg5) + instructions = svc.Convert(armadacontext.Background(), msg5) expected = &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedJobSucceeded}, MessageIds: msg5.MessageIds, @@ -433,7 +433,7 @@ func TestHappyPathMultiUpdate(t *testing.T) { func TestCancelled(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobCancelled) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedJobCancelled}, MessageIds: msg.MessageIds, @@ -444,7 +444,7 @@ func TestCancelled(t *testing.T) { func TestReprioritised(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobReprioritised) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{&expectedJobReprioritised}, MessageIds: msg.MessageIds, @@ -455,7 +455,7 @@ func TestReprioritised(t *testing.T) { func TestPreempted(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobPreempted) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedJobRunPreempted}, MessageIds: msg.MessageIds, @@ -466,7 +466,7 @@ func TestPreempted(t *testing.T) { func TestFailed(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobRunFailed) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{&expectedFailed}, JobRunContainersToCreate: []*model.CreateJobRunContainerInstruction{&expectedJobRunContainer}, @@ -478,7 +478,7 @@ func TestFailed(t *testing.T) { func TestFailedWithMissingRunId(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(jobLeaseReturned) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) jobRun := instructions.JobRunsToCreate[0] assert.NotEqual(t, eventutil.LEGACY_RUN_ID, jobRun.RunId) expected := &model.InstructionSet{ @@ -534,7 +534,7 @@ func TestHandlePodTerminated(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(podTerminated) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ MessageIds: msg.MessageIds, } @@ -565,7 +565,7 @@ func TestHandleJobLeaseReturned(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(leaseReturned) - 
instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{{ RunId: runIdString, @@ -616,7 +616,7 @@ func TestHandlePodUnschedulable(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(podUnschedulable) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobRunsToUpdate: []*model.UpdateJobRunInstruction{{ RunId: runIdString, @@ -639,7 +639,7 @@ func TestHandleDuplicate(t *testing.T) { svc := SimpleInstructionConverter() msg := NewMsg(duplicate) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToUpdate: []*model.UpdateJobInstruction{ { @@ -685,7 +685,7 @@ func TestSubmitWithNullChar(t *testing.T) { }) svc := SimpleInstructionConverter() - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) assert.Len(t, instructions.JobsToCreate, 1) assert.NotContains(t, string(instructions.JobsToCreate[0].JobProto), "\\u0000") } @@ -716,7 +716,7 @@ func TestFailedWithNullCharInError(t *testing.T) { }) svc := SimpleInstructionConverter() - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expectedJobRunsToUpdate := []*model.UpdateJobRunInstruction{ { RunId: runIdString, @@ -741,7 +741,7 @@ func TestInvalidEvent(t *testing.T) { // Check that the (valid) Submit is processed, but the invalid message is discarded svc := SimpleInstructionConverter() msg := NewMsg(invalidEvent, submit) - instructions := svc.Convert(context.Background(), msg) + instructions := svc.Convert(armadacontext.Background(), msg) expected := &model.InstructionSet{ JobsToCreate: []*model.CreateJobInstruction{&expectedSubmit}, MessageIds: msg.MessageIds, diff --git a/internal/lookoutingester/lookoutdb/insertion.go b/internal/lookoutingester/lookoutdb/insertion.go index a22b2eab29b..6009f4cadbc 100644 --- a/internal/lookoutingester/lookoutdb/insertion.go +++ b/internal/lookoutingester/lookoutdb/insertion.go @@ -1,7 +1,6 @@ package lookoutdb import ( - "context" "fmt" "sync" "time" @@ -11,6 +10,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/ingest" @@ -45,7 +45,7 @@ func NewLookoutDb( // * Job Run Updates, New Job Containers // In each case we first try to bach insert the rows using the postgres copy protocol. If this fails then we try a // slower, serial insert and discard any rows that cannot be inserted. 
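The Store doc comment above describes the ingester's two-step write strategy: attempt a single COPY-based batch insert first, and if that fails fall back to slower row-by-row inserts that simply discard rows which cannot be written. A minimal sketch of that fallback pattern, reusing the CreateJobsBatch and CreateJobsScalar helpers whose signatures appear later in this file and assuming the file's existing imports; the wrapper name is hypothetical and the error handling is simplified, so this is an illustration of the comment rather than the file's actual CreateJobs body:

```
// Sketch of the batch-then-scalar fallback; createJobsWithFallback is a
// hypothetical name used only for this illustration.
func (l *LookoutDb) createJobsWithFallback(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) {
	if len(instructions) == 0 {
		return
	}
	// Fast path: insert every row in one COPY-based batch.
	if err := l.CreateJobsBatch(ctx, instructions); err != nil {
		log.WithError(err).Warn("batch insert failed; falling back to row-by-row inserts")
		// Slow path: insert rows individually so a single bad row is
		// discarded instead of failing the whole batch.
		l.CreateJobsScalar(ctx, instructions)
	}
}
```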
-func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSet) error { +func (l *LookoutDb) Store(ctx *armadacontext.Context, instructions *model.InstructionSet) error { jobsToUpdate := instructions.JobsToUpdate jobRunsToUpdate := instructions.JobRunsToUpdate @@ -92,7 +92,7 @@ func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSe return nil } -func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobs(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { if len(instructions) == 0 { return } @@ -109,7 +109,7 @@ func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.Create } } -func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobs(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { if len(instructions) == 0 { return } @@ -127,7 +127,7 @@ func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.Update } } -func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRuns(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { if len(instructions) == 0 { return } @@ -144,7 +144,7 @@ func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.Cre } } -func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRuns(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { if len(instructions) == 0 { return } @@ -161,7 +161,7 @@ func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.Upd } } -func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotations(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { if len(instructions) == 0 { return } @@ -178,7 +178,7 @@ func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*m } } -func (l *LookoutDb) CreateJobRunContainers(ctx context.Context, instructions []*model.CreateJobRunContainerInstruction) { +func (l *LookoutDb) CreateJobRunContainers(ctx *armadacontext.Context, instructions []*model.CreateJobRunContainerInstruction) { if len(instructions) == 0 { return } @@ -195,13 +195,13 @@ func (l *LookoutDb) CreateJobRunContainers(ctx context.Context, instructions []* } } -func (l *LookoutDb) CreateJobsBatch(ctx context.Context, instructions []*model.CreateJobInstruction) error { +func (l *LookoutDb) CreateJobsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") createTmp := func(tx pgx.Tx) error { _, err := tx.Exec(ctx, fmt.Sprintf(` - CREATE TEMPORARY TABLE %s + CREATE TEMPORARY TABLE %s ( job_id varchar(32), queue varchar(512), @@ -258,7 +258,7 @@ func (l *LookoutDb) CreateJobsBatch(ctx context.Context, instructions []*model.C } // CreateJobsScalar will insert jobs one by one into the database -func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { sqlStatement := `INSERT INTO job (job_id, queue, owner, jobset, priority, submitted, 
orig_job_spec, state, job_updated) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT DO NOTHING` @@ -276,7 +276,7 @@ func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.UpdateJobInstruction) error { +func (l *LookoutDb) UpdateJobsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") @@ -337,7 +337,7 @@ func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.U }) } -func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { sqlStatement := `UPDATE job SET priority = coalesce($1, priority), @@ -360,7 +360,7 @@ func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*model.CreateJobRunInstruction) error { +func (l *LookoutDb) CreateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -410,7 +410,7 @@ func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { sqlStatement := `INSERT INTO job_run (run_id, job_id, created, cluster) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING` @@ -428,7 +428,7 @@ func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*model.UpdateJobRunInstruction) error { +func (l *LookoutDb) UpdateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -499,7 +499,7 @@ func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { sqlStatement := `UPDATE job_run SET node = coalesce($1, node), @@ -525,7 +525,7 @@ func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) error { +func (l *LookoutDb) CreateUserAnnotationsBatch(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("user_annotation_lookup") @@ -573,9 +573,9 @@ func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions }) } -func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotationsScalar(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { sqlStatement := `INSERT INTO user_annotation_lookup 
(job_id, key, value) - VALUES ($1, $2, $3) + VALUES ($1, $2, $3) ON CONFLICT DO NOTHING` for _, i := range instructions { err := withDatabaseRetryInsert(func() error { @@ -592,7 +592,7 @@ func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instruction } } -func (l *LookoutDb) CreateJobRunContainersBatch(ctx context.Context, instructions []*model.CreateJobRunContainerInstruction) error { +func (l *LookoutDb) CreateJobRunContainersBatch(ctx *armadacontext.Context, instructions []*model.CreateJobRunContainerInstruction) error { return withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run_container") createTmp := func(tx pgx.Tx) error { @@ -641,7 +641,7 @@ func (l *LookoutDb) CreateJobRunContainersBatch(ctx context.Context, instruction }) } -func (l *LookoutDb) CreateJobRunContainersScalar(ctx context.Context, instructions []*model.CreateJobRunContainerInstruction) { +func (l *LookoutDb) CreateJobRunContainersScalar(ctx *armadacontext.Context, instructions []*model.CreateJobRunContainerInstruction) { sqlStatement := `INSERT INTO job_run_container (run_id, container_name, exit_code) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING` @@ -659,7 +659,7 @@ func (l *LookoutDb) CreateJobRunContainersScalar(ctx context.Context, instructio } } -func batchInsert(ctx context.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, +func batchInsert(ctx *armadacontext.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, insertTmp func(pgx.Tx) error, copyToDest func(pgx.Tx) error, ) error { return pgx.BeginTxFunc(ctx, db, pgx.TxOptions{ @@ -776,7 +776,7 @@ func conflateJobRunUpdates(updates []*model.UpdateJobRunInstruction) []*model.Up // in the terminal state. If, however, the database returns a non-retryable error it will give up and simply not // filter out any events as the job state is undetermined. 
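The comment above explains that filterEventsForTerminalJobs, whose signature change follows, drops updates for jobs already in a terminal state, but keeps everything if the database returns a non-retryable error, since the job states are then undetermined. A rough sketch of that behaviour under stated assumptions: the query text and terminal-state labels are invented for illustration, the metrics parameter is omitted, and the file's existing imports (armadacontext, pgxpool, model) are assumed; the real implementation, retry handling and schema values may differ:

```
// Sketch only: the query and the terminal state values below are assumptions,
// not the lookout schema's actual encoding.
func filterEventsForTerminalJobsSketch(
	ctx *armadacontext.Context,
	db *pgxpool.Pool,
	instructions []*model.UpdateJobInstruction,
) []*model.UpdateJobInstruction {
	jobIds := make([]string, 0, len(instructions))
	for _, instruction := range instructions {
		jobIds = append(jobIds, instruction.JobId)
	}
	rows, err := db.Query(ctx,
		"SELECT job_id FROM job WHERE job_id = any($1) AND state = any($2)",
		jobIds, []string{"succeeded", "failed", "cancelled"})
	if err != nil {
		// Job states are unknown, so keep every update rather than risk
		// dropping events for jobs that are not actually terminal.
		return instructions
	}
	defer rows.Close()
	terminal := make(map[string]bool)
	for rows.Next() {
		var jobId string
		if err := rows.Scan(&jobId); err == nil {
			terminal[jobId] = true
		}
	}
	filtered := make([]*model.UpdateJobInstruction, 0, len(instructions))
	for _, instruction := range instructions {
		if !terminal[instruction.JobId] {
			filtered = append(filtered, instruction)
		}
	}
	return filtered
}
```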
func filterEventsForTerminalJobs( - ctx context.Context, + ctx *armadacontext.Context, db *pgxpool.Pool, instructions []*model.UpdateJobInstruction, m *metrics.Metrics, diff --git a/internal/lookoutingester/lookoutdb/insertion_test.go b/internal/lookoutingester/lookoutdb/insertion_test.go index 25a3ff1af03..079912a68c4 100644 --- a/internal/lookoutingester/lookoutdb/insertion_test.go +++ b/internal/lookoutingester/lookoutdb/insertion_test.go @@ -1,21 +1,19 @@ package lookoutdb import ( - "context" "fmt" "sort" "testing" "time" - "github.com/armadaproject/armada/internal/common/database/lookout" - "github.com/apache/pulsar-client-go/pulsar" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/lookout/configuration" "github.com/armadaproject/armada/internal/lookout/repository" @@ -216,24 +214,24 @@ func TestCreateJobsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Insert - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and test that it's idempotent - err = ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - err = ldb.CreateJobsBatch(context.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + err = ldb.CreateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) assert.Error(t, err) assertNoRows(t, db, "job") return nil @@ -245,29 +243,29 @@ func TestUpdateJobsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Insert - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Update - err = ldb.UpdateJobsBatch(context.Background(), defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) require.NoError(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) - err = ldb.UpdateJobsBatch(context.Background(), defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) require.NoError(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If an update is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job") 
+ _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) - err = ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - err = ldb.UpdateJobsBatch(context.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + err = ldb.UpdateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) assert.Error(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) @@ -280,28 +278,28 @@ func TestUpdateJobsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Insert - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Update - ldb.UpdateJobsScalar(context.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // Insert again and test that it's idempotent - ldb.UpdateJobsScalar(context.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If a update is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) - err = ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - ldb.UpdateJobsScalar(context.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + ldb.UpdateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) return nil @@ -376,13 +374,13 @@ func TestUpdateJobsWithTerminal(t *testing.T) { ldb := getTestLookoutDb(db) // Insert - ldb.CreateJobs(context.Background(), initial) + ldb.CreateJobs(armadacontext.Background(), initial) // Mark the jobs terminal - ldb.UpdateJobs(context.Background(), update1) + ldb.UpdateJobs(armadacontext.Background(), update1) // Update the jobs - these should be discarded - ldb.UpdateJobs(context.Background(), update2) + ldb.UpdateJobs(armadacontext.Background(), update2) // Assert the states are still terminal job := getJob(t, db, jobIdString) @@ -403,22 +401,22 @@ func TestCreateJobsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Simple create - ldb.CreateJobsScalar(context.Background(), defaultInstructionSet().JobsToCreate) + ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and check for idempotency - ldb.CreateJobsScalar(context.Background(), defaultInstructionSet().JobsToCreate) + 
ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should update only the good rows - _, err := ldb.db.Exec(context.Background(), "DELETE FROM job") + _, err := ldb.db.Exec(armadacontext.Background(), "DELETE FROM job") require.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - ldb.CreateJobsScalar(context.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + ldb.CreateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) return nil @@ -430,28 +428,28 @@ func TestCreateJobRunsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(context.Background(), "DELETE FROM job_run") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run") require.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(context.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) assert.Error(t, err) assertNoRows(t, db, "job_run") return nil @@ -463,26 +461,26 @@ func TestCreateJobRunsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - ldb.CreateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - ldb.CreateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we create rows that can be created - _, err = db.Exec(context.Background(), "DELETE FROM job_run") + _, err = 
db.Exec(armadacontext.Background(), "DELETE FROM job_run") require.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - ldb.CreateJobRunsScalar(context.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + ldb.CreateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) return nil @@ -494,33 +492,33 @@ func TestUpdateJobRunsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) // Update - err = ldb.UpdateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that it's idempotent - err = ldb.UpdateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(context.Background(), "DELETE FROM job_run;") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job_run;") require.NoError(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) - err = ldb.UpdateJobRunsBatch(context.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) assert.Error(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, run) @@ -533,33 +531,33 @@ func TestUpdateJobRunsScalar(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) require.NoError(t, err) // Update - ldb.UpdateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that 
it's idempotent - ldb.UpdateJobRunsScalar(context.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) require.NoError(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(context.Background(), "DELETE FROM job_run;") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run;") require.NoError(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(context.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) - ldb.UpdateJobRunsScalar(context.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + ldb.UpdateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) run = getJobRun(t, ldb.db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) return nil @@ -571,28 +569,28 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - err = ldb.CreateUserAnnotationsBatch(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) require.NoError(t, err) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - err = ldb.CreateUserAnnotationsBatch(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) require.NoError(t, err) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(context.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") require.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - err = ldb.CreateUserAnnotationsBatch(context.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) assert.Error(t, err) assertNoRows(t, ldb.db, "user_annotation_lookup") return nil @@ -603,7 +601,7 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { func TestEmptyUpdate(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) - storeErr := ldb.Store(context.Background(), &model.InstructionSet{}) + storeErr := ldb.Store(armadacontext.Background(), &model.InstructionSet{}) require.NoError(t, storeErr) assertNoRows(t, ldb.db, "job") assertNoRows(t, ldb.db, "job_run") @@ -618,26 +616,26 @@ func TestCreateUserAnnotationsScalar(t *testing.T) { 
err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Need to make sure we have a job - err := ldb.CreateJobsBatch(context.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) require.NoError(t, err) // Insert - ldb.CreateUserAnnotationsScalar(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - ldb.CreateUserAnnotationsScalar(context.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(context.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") require.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - ldb.CreateUserAnnotationsScalar(context.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) annotation = getUserAnnotationLookup(t, ldb.db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) return nil @@ -649,7 +647,7 @@ func TestUpdate(t *testing.T) { err := testutil.WithDatabasePgx(func(db *pgxpool.Pool) error { ldb := getTestLookoutDb(db) // Do the update - storeErr := ldb.Store(context.Background(), defaultInstructionSet()) + storeErr := ldb.Store(armadacontext.Background(), defaultInstructionSet()) require.NoError(t, storeErr) job := getJob(t, ldb.db, jobIdString) jobRun := getJobRun(t, ldb.db, runIdString) @@ -748,7 +746,7 @@ func TestConflateJobRunUpdates(T *testing.T) { func getJob(t *testing.T, db *pgxpool.Pool, jobId string) JobRow { job := JobRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT job_id, queue, owner, jobset, priority, submitted, state, duplicate, job_updated, orig_job_spec, cancelled FROM job WHERE job_id = $1`, jobId) err := r.Scan( @@ -771,7 +769,7 @@ func getJob(t *testing.T, db *pgxpool.Pool, jobId string) JobRow { func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { run := JobRunRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT run_id, job_id, cluster, node, created, started, finished, succeeded, error, pod_number, unable_to_schedule, preempted FROM job_run WHERE run_id = $1`, runId) err := r.Scan( @@ -795,7 +793,7 @@ func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { func getJobRunContainer(t *testing.T, db *pgxpool.Pool, runId string) JobRunContainerRow { container := JobRunContainerRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT run_id, container_name, exit_code FROM job_run_container WHERE run_id = $1`, runId) err := r.Scan(&container.RunId, &container.ContainerName, &container.ExitCode) @@ -806,7 +804,7 @@ func getJobRunContainer(t *testing.T, db *pgxpool.Pool, runId 
string) JobRunCont func getUserAnnotationLookup(t *testing.T, db *pgxpool.Pool, jobId string) UserAnnotationRow { annotation := UserAnnotationRow{} r := db.QueryRow( - context.Background(), + armadacontext.Background(), `SELECT job_id, key, value FROM user_annotation_lookup WHERE job_id = $1`, jobId) err := r.Scan(&annotation.JobId, &annotation.Key, &annotation.Value) @@ -816,7 +814,7 @@ func getUserAnnotationLookup(t *testing.T, db *pgxpool.Pool, jobId string) UserA func assertNoRows(t *testing.T, db *pgxpool.Pool, table string) { var count int - r := db.QueryRow(context.Background(), fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) + r := db.QueryRow(armadacontext.Background(), fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) err := r.Scan(&count) require.NoError(t, err) assert.Equal(t, 0, count) diff --git a/internal/lookoutingesterv2/benchmark/benchmark.go b/internal/lookoutingesterv2/benchmark/benchmark.go index 6c808ca14f3..953ccb85483 100644 --- a/internal/lookoutingesterv2/benchmark/benchmark.go +++ b/internal/lookoutingesterv2/benchmark/benchmark.go @@ -1,7 +1,6 @@ package benchmark import ( - "context" "fmt" "math" "math/rand" @@ -12,6 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutingesterv2/configuration" @@ -51,7 +51,7 @@ func benchmarkSubmissions1000(b *testing.B, config configuration.LookoutIngester withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) b.StartTimer() - err := ldb.Store(context.TODO(), instructions) + err := ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } @@ -69,7 +69,7 @@ func benchmarkSubmissions10000(b *testing.B, config configuration.LookoutIngeste withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) b.StartTimer() - err := ldb.Store(context.TODO(), instructions) + err := ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } @@ -99,12 +99,12 @@ func benchmarkUpdates1000(b *testing.B, config configuration.LookoutIngesterV2Co withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) - err := ldb.Store(context.TODO(), initialInstructions) + err := ldb.Store(armadacontext.TODO(), initialInstructions) if err != nil { panic(err) } b.StartTimer() - err = ldb.Store(context.TODO(), instructions) + err = ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } @@ -134,12 +134,12 @@ func benchmarkUpdates10000(b *testing.B, config configuration.LookoutIngesterV2C withDbBenchmark(b, config, func(b *testing.B, db *pgxpool.Pool) { ldb := lookoutdb.NewLookoutDb(db, metrics.Get(), 2, 10) - err := ldb.Store(context.TODO(), initialInstructions) + err := ldb.Store(armadacontext.TODO(), initialInstructions) if err != nil { panic(err) } b.StartTimer() - err = ldb.Store(context.TODO(), instructions) + err = ldb.Store(armadacontext.TODO(), instructions) if err != nil { panic(err) } diff --git a/internal/lookoutingesterv2/instructions/instructions.go b/internal/lookoutingesterv2/instructions/instructions.go index 49e519f5eb4..25decf00f2a 100644 --- a/internal/lookoutingesterv2/instructions/instructions.go +++ b/internal/lookoutingesterv2/instructions/instructions.go @@ 
-1,7 +1,6 @@ package instructions import ( - "context" "fmt" "sort" "strings" @@ -14,6 +13,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/eventutil" @@ -65,7 +65,7 @@ func (c *InstructionConverter) IsLegacy() bool { return c.useLegacyEventConversion } -func (c *InstructionConverter) Convert(ctx context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { +func (c *InstructionConverter) Convert(ctx *armadacontext.Context, sequencesWithIds *ingest.EventSequencesWithIds) *model.InstructionSet { updateInstructions := &model.InstructionSet{ MessageIds: sequencesWithIds.MessageIds, } @@ -77,7 +77,7 @@ func (c *InstructionConverter) Convert(ctx context.Context, sequencesWithIds *in } func (c *InstructionConverter) convertSequence( - ctx context.Context, + ctx *armadacontext.Context, sequence *armadaevents.EventSequence, update *model.InstructionSet, ) { diff --git a/internal/lookoutingesterv2/instructions/instructions_test.go b/internal/lookoutingesterv2/instructions/instructions_test.go index d70d7d3900d..36e58983283 100644 --- a/internal/lookoutingesterv2/instructions/instructions_test.go +++ b/internal/lookoutingesterv2/instructions/instructions_test.go @@ -1,7 +1,6 @@ package instructions import ( - "context" "fmt" "strings" "testing" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/eventutil" @@ -560,7 +560,7 @@ func TestConvert(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { converter := NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, tc.useLegacyEventConversion) - instructionSet := converter.Convert(context.TODO(), tc.events) + instructionSet := converter.Convert(armadacontext.TODO(), tc.events) assert.Equal(t, tc.expected.JobsToCreate, instructionSet.JobsToCreate) assert.Equal(t, tc.expected.JobsToUpdate, instructionSet.JobsToUpdate) assert.Equal(t, tc.expected.JobRunsToCreate, instructionSet.JobRunsToCreate) @@ -571,7 +571,7 @@ func TestConvert(t *testing.T) { func TestFailedWithMissingRunId(t *testing.T) { converter := NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, true) - instructions := converter.Convert(context.Background(), &ingest.EventSequencesWithIds{ + instructions := converter.Convert(armadacontext.Background(), &ingest.EventSequencesWithIds{ EventSequences: []*armadaevents.EventSequence{testfixtures.NewEventSequence(testfixtures.JobLeaseReturned)}, MessageIds: []pulsar.MessageID{pulsarutils.NewMessageId(1)}, }) @@ -631,7 +631,7 @@ func TestTruncatesStringsThatAreTooLong(t *testing.T) { } converter := NewInstructionConverter(metrics.Get(), userAnnotationPrefix, &compress.NoOpCompressor{}, true) - actual := converter.Convert(context.TODO(), events) + actual := converter.Convert(armadacontext.TODO(), events) // String lengths obtained from database schema assert.Len(t, actual.JobsToCreate[0].Queue, 512) diff --git a/internal/lookoutingesterv2/lookoutdb/insertion.go 
b/internal/lookoutingesterv2/lookoutdb/insertion.go index c5378543df0..2e13c453213 100644 --- a/internal/lookoutingesterv2/lookoutdb/insertion.go +++ b/internal/lookoutingesterv2/lookoutdb/insertion.go @@ -1,7 +1,6 @@ package lookoutdb import ( - "context" "fmt" "sync" "time" @@ -11,6 +10,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/database/lookout" @@ -36,7 +36,7 @@ func NewLookoutDb(db *pgxpool.Pool, metrics *metrics.Metrics, maxAttempts int, m // * Job Run Updates // In each case we first try to bach insert the rows using the postgres copy protocol. If this fails then we try a // slower, serial insert and discard any rows that cannot be inserted. -func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSet) error { +func (l *LookoutDb) Store(ctx *armadacontext.Context, instructions *model.InstructionSet) error { // We might have multiple updates for the same job or job run // These can be conflated to help performance jobsToUpdate := conflateJobUpdates(instructions.JobsToUpdate) @@ -68,7 +68,7 @@ func (l *LookoutDb) Store(ctx context.Context, instructions *model.InstructionSe return nil } -func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobs(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { if len(instructions) == 0 { return } @@ -79,7 +79,7 @@ func (l *LookoutDb) CreateJobs(ctx context.Context, instructions []*model.Create } } -func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobs(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { if len(instructions) == 0 { return } @@ -91,7 +91,7 @@ func (l *LookoutDb) UpdateJobs(ctx context.Context, instructions []*model.Update } } -func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRuns(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { if len(instructions) == 0 { return } @@ -102,7 +102,7 @@ func (l *LookoutDb) CreateJobRuns(ctx context.Context, instructions []*model.Cre } } -func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRuns(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { if len(instructions) == 0 { return } @@ -113,7 +113,7 @@ func (l *LookoutDb) UpdateJobRuns(ctx context.Context, instructions []*model.Upd } } -func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotations(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { if len(instructions) == 0 { return } @@ -124,7 +124,7 @@ func (l *LookoutDb) CreateUserAnnotations(ctx context.Context, instructions []*m } } -func (l *LookoutDb) CreateJobsBatch(ctx context.Context, instructions []*model.CreateJobInstruction) error { +func (l *LookoutDb) CreateJobsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") @@ -231,7 +231,7 @@ func (l *LookoutDb) 
CreateJobsBatch(ctx context.Context, instructions []*model.C } // CreateJobsScalar will insert jobs one by one into the database -func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model.CreateJobInstruction) { +func (l *LookoutDb) CreateJobsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobInstruction) { sqlStatement := `INSERT INTO job ( job_id, queue, @@ -279,7 +279,7 @@ func (l *LookoutDb) CreateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.UpdateJobInstruction) error { +func (l *LookoutDb) UpdateJobsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job") @@ -358,7 +358,7 @@ func (l *LookoutDb) UpdateJobsBatch(ctx context.Context, instructions []*model.U }) } -func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model.UpdateJobInstruction) { +func (l *LookoutDb) UpdateJobsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobInstruction) { sqlStatement := `UPDATE job SET priority = coalesce($2, priority), @@ -393,7 +393,7 @@ func (l *LookoutDb) UpdateJobsScalar(ctx context.Context, instructions []*model. } } -func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*model.CreateJobRunInstruction) error { +func (l *LookoutDb) CreateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -464,7 +464,7 @@ func (l *LookoutDb) CreateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*model.CreateJobRunInstruction) { +func (l *LookoutDb) CreateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.CreateJobRunInstruction) { sqlStatement := `INSERT INTO job_run ( run_id, job_id, @@ -496,7 +496,7 @@ func (l *LookoutDb) CreateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*model.UpdateJobRunInstruction) error { +func (l *LookoutDb) UpdateJobRunsBatch(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("job_run") @@ -571,7 +571,7 @@ func (l *LookoutDb) UpdateJobRunsBatch(ctx context.Context, instructions []*mode }) } -func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*model.UpdateJobRunInstruction) { +func (l *LookoutDb) UpdateJobRunsScalar(ctx *armadacontext.Context, instructions []*model.UpdateJobRunInstruction) { sqlStatement := `UPDATE job_run SET node = coalesce($2, node), @@ -604,7 +604,7 @@ func (l *LookoutDb) UpdateJobRunsScalar(ctx context.Context, instructions []*mod } } -func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions []*model.CreateUserAnnotationInstruction) error { +func (l *LookoutDb) CreateUserAnnotationsBatch(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) error { return l.withDatabaseRetryInsert(func() error { tmpTable := database.UniqueTableName("user_annotation_lookup") @@ -667,7 +667,7 @@ func (l *LookoutDb) CreateUserAnnotationsBatch(ctx context.Context, instructions }) } -func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instructions 
[]*model.CreateUserAnnotationInstruction) { +func (l *LookoutDb) CreateUserAnnotationsScalar(ctx *armadacontext.Context, instructions []*model.CreateUserAnnotationInstruction) { sqlStatement := `INSERT INTO user_annotation_lookup ( job_id, key, @@ -696,7 +696,7 @@ func (l *LookoutDb) CreateUserAnnotationsScalar(ctx context.Context, instruction } } -func batchInsert(ctx context.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, +func batchInsert(ctx *armadacontext.Context, db *pgxpool.Pool, createTmp func(pgx.Tx) error, insertTmp func(pgx.Tx) error, copyToDest func(pgx.Tx) error, ) error { return pgx.BeginTxFunc(ctx, db, pgx.TxOptions{ @@ -834,7 +834,7 @@ type updateInstructionsForJob struct { // in the terminal state. If, however, the database returns a non-retryable error it will give up and simply not // filter out any events as the job state is undetermined. func (l *LookoutDb) filterEventsForTerminalJobs( - ctx context.Context, + ctx *armadacontext.Context, db *pgxpool.Pool, instructions []*model.UpdateJobInstruction, m *metrics.Metrics, diff --git a/internal/lookoutingesterv2/lookoutdb/insertion_test.go b/internal/lookoutingesterv2/lookoutdb/insertion_test.go index 9de584df3fa..13b64c12365 100644 --- a/internal/lookoutingesterv2/lookoutdb/insertion_test.go +++ b/internal/lookoutingesterv2/lookoutdb/insertion_test.go @@ -1,7 +1,6 @@ package lookoutdb import ( - ctx "context" "fmt" "sort" "testing" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/lookoutingesterv2/metrics" @@ -202,24 +202,24 @@ func TestCreateJobsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Insert - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and test that it's idempotent - err = ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - err = ldb.CreateJobsBatch(ctx.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + err = ldb.CreateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) assert.Error(t, err) assertNoRows(t, db, "job") return nil @@ -231,29 +231,29 @@ func TestUpdateJobsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Insert - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Update - err = ldb.UpdateJobsBatch(ctx.Background(), 
defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) assert.Nil(t, err) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) - err = ldb.UpdateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToUpdate) + err = ldb.UpdateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) assert.Nil(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If an update is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) - err = ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - err = ldb.UpdateJobsBatch(ctx.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + err = ldb.UpdateJobsBatch(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) assert.Error(t, err) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) @@ -266,28 +266,28 @@ func TestUpdateJobsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Insert - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Update - ldb.UpdateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // Insert again and test that it's idempotent - ldb.UpdateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToUpdate) + ldb.UpdateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToUpdate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) // If a update is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) - err = ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err = ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) invalidUpdate := &model.UpdateJobInstruction{ JobId: invalidId, } - ldb.UpdateJobsScalar(ctx.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) + ldb.UpdateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToUpdate, invalidUpdate)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterUpdate, job) return nil @@ -399,13 +399,13 @@ func TestUpdateJobsWithTerminal(t *testing.T) { ldb := NewLookoutDb(db, m, 2, 10) // Insert - ldb.CreateJobs(ctx.Background(), initial) + ldb.CreateJobs(armadacontext.Background(), initial) // Mark the jobs terminal - ldb.UpdateJobs(ctx.Background(), update1) + ldb.UpdateJobs(armadacontext.Background(), update1) // Update the jobs - these should be discarded - ldb.UpdateJobs(ctx.Background(), update2) + ldb.UpdateJobs(armadacontext.Background(), update2) // Assert the states are still terminal job 
:= getJob(t, db, jobIdString) @@ -427,22 +427,22 @@ func TestCreateJobsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Simple create - ldb.CreateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToCreate) + ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job := getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // Insert again and check for idempotency - ldb.CreateJobsScalar(ctx.Background(), defaultInstructionSet().JobsToCreate) + ldb.CreateJobsScalar(armadacontext.Background(), defaultInstructionSet().JobsToCreate) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) // If a row is bad then we should update only the good rows - _, err := ldb.db.Exec(ctx.Background(), "DELETE FROM job") + _, err := ldb.db.Exec(armadacontext.Background(), "DELETE FROM job") assert.NoError(t, err) invalidJob := &model.CreateJobInstruction{ JobId: invalidId, } - ldb.CreateJobsScalar(ctx.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) + ldb.CreateJobsScalar(armadacontext.Background(), append(defaultInstructionSet().JobsToCreate, invalidJob)) job = getJob(t, db, jobIdString) assert.Equal(t, expectedJobAfterSubmit, job) return nil @@ -454,28 +454,28 @@ func TestCreateJobRunsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM job_run") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run") assert.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(ctx.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) assert.Error(t, err) assertNoRows(t, db, "job_run") return nil @@ -487,26 +487,26 @@ func TestCreateJobRunsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job, so we can satisfy PK - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - ldb.CreateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), 
defaultInstructionSet().JobRunsToCreate) job := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // Insert again and test that it's idempotent - ldb.CreateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + ldb.CreateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) // If a row is bad then we create rows that can be created - _, err = db.Exec(ctx.Background(), "DELETE FROM job_run") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job_run") assert.NoError(t, err) invalidRun := &model.CreateJobRunInstruction{ RunId: invalidId, } - ldb.CreateJobRunsScalar(ctx.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) + ldb.CreateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToCreate, invalidRun)) job = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, job) return nil @@ -518,33 +518,33 @@ func TestUpdateJobRunsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) // Update - err = ldb.UpdateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that it's idempotent - err = ldb.UpdateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should return an error and no updates should happen - _, err = db.Exec(ctx.Background(), "DELETE FROM job_run;") + _, err = db.Exec(armadacontext.Background(), "DELETE FROM job_run;") assert.Nil(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) - err = ldb.UpdateJobRunsBatch(ctx.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + err = ldb.UpdateJobRunsBatch(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) assert.Error(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRun, run) @@ -557,33 +557,33 @@ func TestUpdateJobRunsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job and run - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) - err = ldb.CreateJobRunsBatch(ctx.Background(), 
defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) // Update - ldb.UpdateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run := getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // Update again and test that it's idempotent - ldb.UpdateJobRunsScalar(ctx.Background(), defaultInstructionSet().JobRunsToUpdate) + ldb.UpdateJobRunsScalar(armadacontext.Background(), defaultInstructionSet().JobRunsToUpdate) assert.Nil(t, err) run = getJobRun(t, db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM job_run;") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_run;") assert.Nil(t, err) invalidRun := &model.UpdateJobRunInstruction{ RunId: invalidId, } - err = ldb.CreateJobRunsBatch(ctx.Background(), defaultInstructionSet().JobRunsToCreate) + err = ldb.CreateJobRunsBatch(armadacontext.Background(), defaultInstructionSet().JobRunsToCreate) assert.Nil(t, err) - ldb.UpdateJobRunsScalar(ctx.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) + ldb.UpdateJobRunsScalar(armadacontext.Background(), append(defaultInstructionSet().JobRunsToUpdate, invalidRun)) run = getJobRun(t, ldb.db, runIdString) assert.Equal(t, expectedJobRunAfterUpdate, run) return nil @@ -595,28 +595,28 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - err = ldb.CreateUserAnnotationsBatch(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) assert.Nil(t, err) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - err = ldb.CreateUserAnnotationsBatch(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) assert.Nil(t, err) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should return an error and no updates should happen - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") assert.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - err = ldb.CreateUserAnnotationsBatch(ctx.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + err = ldb.CreateUserAnnotationsBatch(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) assert.Error(t, err) assertNoRows(t, ldb.db, "user_annotation_lookup") return nil @@ -627,7 +627,7 @@ func TestCreateUserAnnotationsBatch(t *testing.T) { func 
TestStoreWithEmptyInstructionSet(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) - err := ldb.Store(ctx.Background(), &model.InstructionSet{ + err := ldb.Store(armadacontext.Background(), &model.InstructionSet{ MessageIds: []pulsar.MessageID{pulsarutils.NewMessageId(1)}, }) assert.NoError(t, err) @@ -643,26 +643,26 @@ func TestCreateUserAnnotationsScalar(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Need to make sure we have a job - err := ldb.CreateJobsBatch(ctx.Background(), defaultInstructionSet().JobsToCreate) + err := ldb.CreateJobsBatch(armadacontext.Background(), defaultInstructionSet().JobsToCreate) assert.Nil(t, err) // Insert - ldb.CreateUserAnnotationsScalar(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation := getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // Insert again and test that it's idempotent - ldb.CreateUserAnnotationsScalar(ctx.Background(), defaultInstructionSet().UserAnnotationsToCreate) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), defaultInstructionSet().UserAnnotationsToCreate) annotation = getUserAnnotationLookup(t, db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) // If a row is bad then we should update the rows we can - _, err = ldb.db.Exec(ctx.Background(), "DELETE FROM user_annotation_lookup") + _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM user_annotation_lookup") assert.NoError(t, err) invalidAnnotation := &model.CreateUserAnnotationInstruction{ JobId: invalidId, } - ldb.CreateUserAnnotationsScalar(ctx.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) + ldb.CreateUserAnnotationsScalar(armadacontext.Background(), append(defaultInstructionSet().UserAnnotationsToCreate, invalidAnnotation)) annotation = getUserAnnotationLookup(t, ldb.db, jobIdString) assert.Equal(t, expectedUserAnnotation, annotation) return nil @@ -674,7 +674,7 @@ func TestStore(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { ldb := NewLookoutDb(db, m, 2, 10) // Do the update - err := ldb.Store(ctx.Background(), defaultInstructionSet()) + err := ldb.Store(armadacontext.Background(), defaultInstructionSet()) assert.NoError(t, err) job := getJob(t, ldb.db, jobIdString) @@ -843,7 +843,7 @@ func TestStoreNullValue(t *testing.T) { ldb := NewLookoutDb(db, m, 2, 10) // Do the update - err := ldb.Store(ctx.Background(), instructions) + err := ldb.Store(armadacontext.Background(), instructions) assert.NoError(t, err) job := getJob(t, ldb.db, jobIdString) @@ -875,7 +875,7 @@ func TestStoreEventsForAlreadyTerminalJobs(t *testing.T) { } // Create the jobs in the DB - err := ldb.Store(ctx.Background(), baseInstructions) + err := ldb.Store(armadacontext.Background(), baseInstructions) assert.NoError(t, err) mutateInstructions := &model.InstructionSet{ @@ -895,7 +895,7 @@ func TestStoreEventsForAlreadyTerminalJobs(t *testing.T) { } // Update the jobs in the DB - err = ldb.Store(ctx.Background(), mutateInstructions) + err = ldb.Store(armadacontext.Background(), mutateInstructions) assert.NoError(t, err) for _, jobId := range []string{"job-1", "job-2", "job-3"} { @@ -941,7 +941,7 @@ func makeUpdateJobInstruction(jobId string, state int32) *model.UpdateJobInstruc func getJob(t 
*testing.T, db *pgxpool.Pool, jobId string) JobRow { job := JobRow{} r := db.QueryRow( - ctx.Background(), + armadacontext.Background(), `SELECT job_id, queue, @@ -992,7 +992,7 @@ func getJob(t *testing.T, db *pgxpool.Pool, jobId string) JobRow { func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { run := JobRunRow{} r := db.QueryRow( - ctx.Background(), + armadacontext.Background(), `SELECT run_id, job_id, @@ -1025,7 +1025,7 @@ func getJobRun(t *testing.T, db *pgxpool.Pool, runId string) JobRunRow { func getUserAnnotationLookup(t *testing.T, db *pgxpool.Pool, jobId string) UserAnnotationRow { annotation := UserAnnotationRow{} r := db.QueryRow( - ctx.Background(), + armadacontext.Background(), `SELECT job_id, key, value, queue, jobset FROM user_annotation_lookup WHERE job_id = $1`, jobId) err := r.Scan(&annotation.JobId, &annotation.Key, &annotation.Value, &annotation.Queue, &annotation.JobSet) @@ -1037,7 +1037,7 @@ func assertNoRows(t *testing.T, db *pgxpool.Pool, table string) { t.Helper() var count int query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) - r := db.QueryRow(ctx.Background(), query) + r := db.QueryRow(armadacontext.Background(), query) err := r.Scan(&count) assert.NoError(t, err) assert.Equal(t, 0, count) diff --git a/internal/lookoutv2/application.go b/internal/lookoutv2/application.go index ca6844f8b32..0b0fa42bb86 100644 --- a/internal/lookoutv2/application.go +++ b/internal/lookoutv2/application.go @@ -3,10 +3,12 @@ package lookoutv2 import ( + "github.com/caarlos0/log" "github.com/go-openapi/loads" "github.com/go-openapi/runtime/middleware" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" @@ -38,6 +40,8 @@ func Serve(configuration configuration.LookoutV2Configuration) error { // create new service API api := operations.NewLookoutAPI(swaggerSpec) + logger := logrus.NewEntry(logrus.New()) + api.GetHealthHandler = operations.GetHealthHandlerFunc( func(params operations.GetHealthParams) middleware.Responder { return operations.NewGetHealthOK().WithPayload("Health check passed") @@ -53,7 +57,7 @@ func Serve(configuration configuration.LookoutV2Configuration) error { skip = int(*params.GetJobsRequest.Skip) } result, err := getJobsRepo.GetJobs( - params.HTTPRequest.Context(), + armadacontext.New(params.HTTPRequest.Context(), logger), filters, params.GetJobsRequest.ActiveJobSets, order, @@ -78,7 +82,7 @@ func Serve(configuration configuration.LookoutV2Configuration) error { skip = int(*params.GroupJobsRequest.Skip) } result, err := groupJobsRepo.GroupBy( - params.HTTPRequest.Context(), + armadacontext.New(params.HTTPRequest.Context(), logger), filters, params.GroupJobsRequest.ActiveJobSets, order, @@ -98,7 +102,8 @@ func Serve(configuration configuration.LookoutV2Configuration) error { api.GetJobRunErrorHandler = operations.GetJobRunErrorHandlerFunc( func(params operations.GetJobRunErrorParams) middleware.Responder { - result, err := getJobRunErrorRepo.GetJobRunError(params.HTTPRequest.Context(), params.GetJobRunErrorRequest.RunID) + ctx := armadacontext.New(params.HTTPRequest.Context(), logger) + result, err := getJobRunErrorRepo.GetJobRunError(ctx, params.GetJobRunErrorRequest.RunID) if err != nil { return operations.NewGetJobRunErrorBadRequest().WithPayload(conversions.ToSwaggerError(err.Error())) } 
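
The hunks above and throughout this patch follow one pattern: callers that previously built a plain `context.Context` now build an `*armadacontext.Context`, and that value is passed unchanged to code (such as pgx) that still expects `context.Context`. Below is a minimal usage sketch, assuming only the constructors visible in this diff (`armadacontext.Background`, `armadacontext.TODO`, `armadacontext.New`) and that `*armadacontext.Context` satisfies `context.Context`, which the pgx calls in these hunks imply; the handler path, listen address, and `doWork` helper are illustrative and not part of the Armada codebase.

```
package main

import (
	"context"
	"net/http"

	"github.com/sirupsen/logrus"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

// doWork stands in for a repository call. It takes a plain context.Context,
// which *armadacontext.Context is assumed to satisfy, so existing
// context-aware code keeps compiling unchanged.
func doWork(ctx context.Context) error {
	return ctx.Err()
}

func main() {
	logger := logrus.NewEntry(logrus.New())

	// Drop-in replacements for context.Background()/context.TODO(), as used
	// throughout the test and benchmark hunks in this patch.
	_ = doWork(armadacontext.Background())
	_ = doWork(armadacontext.TODO())

	// Request-scoped wrapping, mirroring the lookoutv2 application.go hunks:
	// the incoming request context and a logrus entry are combined into one
	// value that carries both cancellation and logging.
	http.HandleFunc("/example", func(w http.ResponseWriter, r *http.Request) {
		ctx := armadacontext.New(r.Context(), logger)
		if err := doWork(ctx); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	})

	_ = http.ListenAndServe(":8080", nil) // illustrative address only
}
```

As in the application.go hunk above, the same logger travels with the context, so downstream repositories can log with request scope without threading a separate logger parameter through every call.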
@@ -110,7 +115,8 @@ func Serve(configuration configuration.LookoutV2Configuration) error { api.GetJobSpecHandler = operations.GetJobSpecHandlerFunc( func(params operations.GetJobSpecParams) middleware.Responder { - result, err := getJobSpecRepo.GetJobSpec(params.HTTPRequest.Context(), params.GetJobSpecRequest.JobID) + ctx := armadacontext.New(params.HTTPRequest.Context(), logger) + result, err := getJobSpecRepo.GetJobSpec(ctx, params.GetJobSpecRequest.JobID) if err != nil { return operations.NewGetJobSpecBadRequest().WithPayload(conversions.ToSwaggerError(err.Error())) } diff --git a/internal/lookoutv2/gen/restapi/doc.go b/internal/lookoutv2/gen/restapi/doc.go index 23beb22a1a0..a8686cb04ea 100644 --- a/internal/lookoutv2/gen/restapi/doc.go +++ b/internal/lookoutv2/gen/restapi/doc.go @@ -2,18 +2,18 @@ // Package restapi Lookout v2 API // -// Schemes: -// http -// Host: localhost -// BasePath: / -// Version: 2.0.0 +// Schemes: +// http +// Host: localhost +// BasePath: / +// Version: 2.0.0 // -// Consumes: -// - application/json +// Consumes: +// - application/json // -// Produces: -// - application/json -// - text/plain +// Produces: +// - application/json +// - text/plain // // swagger:meta package restapi diff --git a/internal/lookoutv2/gen/restapi/operations/get_health.go b/internal/lookoutv2/gen/restapi/operations/get_health.go index 16cd6803823..d7c8a7dc5ad 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_health.go +++ b/internal/lookoutv2/gen/restapi/operations/get_health.go @@ -29,10 +29,10 @@ func NewGetHealth(ctx *middleware.Context, handler GetHealthHandler) *GetHealth return &GetHealth{Context: ctx, Handler: handler} } -/* GetHealth swagger:route GET /health getHealth +/* + GetHealth swagger:route GET /health getHealth GetHealth get health API - */ type GetHealth struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_health_responses.go b/internal/lookoutv2/gen/restapi/operations/get_health_responses.go index 032b8c2cb0d..c54a26244c4 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_health_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_health_responses.go @@ -14,7 +14,8 @@ import ( // GetHealthOKCode is the HTTP code returned for type GetHealthOK const GetHealthOKCode int = 200 -/*GetHealthOK OK +/* +GetHealthOK OK swagger:response getHealthOK */ @@ -56,7 +57,8 @@ func (o *GetHealthOK) WriteResponse(rw http.ResponseWriter, producer runtime.Pro // GetHealthBadRequestCode is the HTTP code returned for type GetHealthBadRequest const GetHealthBadRequestCode int = 400 -/*GetHealthBadRequest Error response +/* +GetHealthBadRequest Error response swagger:response getHealthBadRequest */ diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go b/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go index 537d2663379..f8add74ee45 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_run_error.go @@ -34,10 +34,10 @@ func NewGetJobRunError(ctx *middleware.Context, handler GetJobRunErrorHandler) * return &GetJobRunError{Context: ctx, Handler: handler} } -/* GetJobRunError swagger:route POST /api/v1/jobRunError getJobRunError +/* + GetJobRunError swagger:route POST /api/v1/jobRunError getJobRunError GetJobRunError get job run error API - */ type GetJobRunError struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go 
b/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go index ff1a82e47c2..e8a17e5b37d 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_run_error_responses.go @@ -16,7 +16,8 @@ import ( // GetJobRunErrorOKCode is the HTTP code returned for type GetJobRunErrorOK const GetJobRunErrorOKCode int = 200 -/*GetJobRunErrorOK Returns error for specific job run (if present) +/* +GetJobRunErrorOK Returns error for specific job run (if present) swagger:response getJobRunErrorOK */ @@ -60,7 +61,8 @@ func (o *GetJobRunErrorOK) WriteResponse(rw http.ResponseWriter, producer runtim // GetJobRunErrorBadRequestCode is the HTTP code returned for type GetJobRunErrorBadRequest const GetJobRunErrorBadRequestCode int = 400 -/*GetJobRunErrorBadRequest Error response +/* +GetJobRunErrorBadRequest Error response swagger:response getJobRunErrorBadRequest */ @@ -101,7 +103,8 @@ func (o *GetJobRunErrorBadRequest) WriteResponse(rw http.ResponseWriter, produce } } -/*GetJobRunErrorDefault Error response +/* +GetJobRunErrorDefault Error response swagger:response getJobRunErrorDefault */ diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_spec.go b/internal/lookoutv2/gen/restapi/operations/get_job_spec.go index a0ee4726d38..74055af08f8 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_spec.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_spec.go @@ -34,10 +34,10 @@ func NewGetJobSpec(ctx *middleware.Context, handler GetJobSpecHandler) *GetJobSp return &GetJobSpec{Context: ctx, Handler: handler} } -/* GetJobSpec swagger:route POST /api/v1/jobSpec getJobSpec +/* + GetJobSpec swagger:route POST /api/v1/jobSpec getJobSpec GetJobSpec get job spec API - */ type GetJobSpec struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go b/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go index ccccd693330..8c4776d0f47 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_job_spec_responses.go @@ -16,7 +16,8 @@ import ( // GetJobSpecOKCode is the HTTP code returned for type GetJobSpecOK const GetJobSpecOKCode int = 200 -/*GetJobSpecOK Returns raw Job spec +/* +GetJobSpecOK Returns raw Job spec swagger:response getJobSpecOK */ @@ -60,7 +61,8 @@ func (o *GetJobSpecOK) WriteResponse(rw http.ResponseWriter, producer runtime.Pr // GetJobSpecBadRequestCode is the HTTP code returned for type GetJobSpecBadRequest const GetJobSpecBadRequestCode int = 400 -/*GetJobSpecBadRequest Error response +/* +GetJobSpecBadRequest Error response swagger:response getJobSpecBadRequest */ @@ -101,7 +103,8 @@ func (o *GetJobSpecBadRequest) WriteResponse(rw http.ResponseWriter, producer ru } } -/*GetJobSpecDefault Error response +/* +GetJobSpecDefault Error response swagger:response getJobSpecDefault */ diff --git a/internal/lookoutv2/gen/restapi/operations/get_jobs.go b/internal/lookoutv2/gen/restapi/operations/get_jobs.go index 76689ed77d0..b498f593901 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_jobs.go +++ b/internal/lookoutv2/gen/restapi/operations/get_jobs.go @@ -37,10 +37,10 @@ func NewGetJobs(ctx *middleware.Context, handler GetJobsHandler) *GetJobs { return &GetJobs{Context: ctx, Handler: handler} } -/* GetJobs swagger:route POST /api/v1/jobs getJobs +/* + GetJobs swagger:route POST /api/v1/jobs getJobs GetJobs get jobs API - */ type GetJobs 
struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go b/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go index 2b1802191f6..5af80b4f316 100644 --- a/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/get_jobs_responses.go @@ -16,7 +16,8 @@ import ( // GetJobsOKCode is the HTTP code returned for type GetJobsOK const GetJobsOKCode int = 200 -/*GetJobsOK Returns jobs from API +/* +GetJobsOK Returns jobs from API swagger:response getJobsOK */ @@ -60,7 +61,8 @@ func (o *GetJobsOK) WriteResponse(rw http.ResponseWriter, producer runtime.Produ // GetJobsBadRequestCode is the HTTP code returned for type GetJobsBadRequest const GetJobsBadRequestCode int = 400 -/*GetJobsBadRequest Error response +/* +GetJobsBadRequest Error response swagger:response getJobsBadRequest */ @@ -101,7 +103,8 @@ func (o *GetJobsBadRequest) WriteResponse(rw http.ResponseWriter, producer runti } } -/*GetJobsDefault Error response +/* +GetJobsDefault Error response swagger:response getJobsDefault */ diff --git a/internal/lookoutv2/gen/restapi/operations/group_jobs.go b/internal/lookoutv2/gen/restapi/operations/group_jobs.go index 4225045294b..208d7856c68 100644 --- a/internal/lookoutv2/gen/restapi/operations/group_jobs.go +++ b/internal/lookoutv2/gen/restapi/operations/group_jobs.go @@ -37,10 +37,10 @@ func NewGroupJobs(ctx *middleware.Context, handler GroupJobsHandler) *GroupJobs return &GroupJobs{Context: ctx, Handler: handler} } -/* GroupJobs swagger:route POST /api/v1/jobGroups groupJobs +/* + GroupJobs swagger:route POST /api/v1/jobGroups groupJobs GroupJobs group jobs API - */ type GroupJobs struct { Context *middleware.Context diff --git a/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go b/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go index ff442c870bc..b34b787fbbf 100644 --- a/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go +++ b/internal/lookoutv2/gen/restapi/operations/group_jobs_responses.go @@ -16,7 +16,8 @@ import ( // GroupJobsOKCode is the HTTP code returned for type GroupJobsOK const GroupJobsOKCode int = 200 -/*GroupJobsOK Returns job groups from API +/* +GroupJobsOK Returns job groups from API swagger:response groupJobsOK */ @@ -60,7 +61,8 @@ func (o *GroupJobsOK) WriteResponse(rw http.ResponseWriter, producer runtime.Pro // GroupJobsBadRequestCode is the HTTP code returned for type GroupJobsBadRequest const GroupJobsBadRequestCode int = 400 -/*GroupJobsBadRequest Error response +/* +GroupJobsBadRequest Error response swagger:response groupJobsBadRequest */ @@ -101,7 +103,8 @@ func (o *GroupJobsBadRequest) WriteResponse(rw http.ResponseWriter, producer run } } -/*GroupJobsDefault Error response +/* +GroupJobsDefault Error response swagger:response groupJobsDefault */ diff --git a/internal/lookoutv2/pruner/pruner.go b/internal/lookoutv2/pruner/pruner.go index 18ee81c8da1..946917fe30a 100644 --- a/internal/lookoutv2/pruner/pruner.go +++ b/internal/lookoutv2/pruner/pruner.go @@ -1,16 +1,17 @@ package pruner import ( - "context" "time" "github.com/jackc/pgx/v5" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) -func PruneDb(ctx context.Context, db *pgx.Conn, keepAfterCompletion time.Duration, batchLimit int, clock clock.Clock) error { +func PruneDb(ctx *armadacontext.Context, db *pgx.Conn, keepAfterCompletion 
time.Duration, batchLimit int, clock clock.Clock) error { now := clock.Now() cutOffTime := now.Add(-keepAfterCompletion) totalJobsToDelete, err := createJobIdsToDeleteTempTable(ctx, db, cutOffTime) @@ -60,10 +61,10 @@ func PruneDb(ctx context.Context, db *pgx.Conn, keepAfterCompletion time.Duratio } // Returns total number of jobs to delete -func createJobIdsToDeleteTempTable(ctx context.Context, db *pgx.Conn, cutOffTime time.Time) (int, error) { +func createJobIdsToDeleteTempTable(ctx *armadacontext.Context, db *pgx.Conn, cutOffTime time.Time) (int, error) { _, err := db.Exec(ctx, ` CREATE TEMP TABLE job_ids_to_delete AS ( - SELECT job_id FROM job + SELECT job_id FROM job WHERE last_transition_time < $1 )`, cutOffTime) if err != nil { @@ -77,7 +78,7 @@ func createJobIdsToDeleteTempTable(ctx context.Context, db *pgx.Conn, cutOffTime return totalJobsToDelete, nil } -func deleteBatch(ctx context.Context, tx pgx.Tx, batchLimit int) (int, error) { +func deleteBatch(ctx *armadacontext.Context, tx pgx.Tx, batchLimit int) (int, error) { _, err := tx.Exec(ctx, "INSERT INTO batch (job_id) SELECT job_id FROM job_ids_to_delete LIMIT $1;", batchLimit) if err != nil { return -1, err diff --git a/internal/lookoutv2/pruner/pruner_test.go b/internal/lookoutv2/pruner/pruner_test.go index a88274c316a..3df18c0cf05 100644 --- a/internal/lookoutv2/pruner/pruner_test.go +++ b/internal/lookoutv2/pruner/pruner_test.go @@ -1,7 +1,6 @@ package pruner import ( - "context" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" @@ -111,7 +111,7 @@ func TestPruneDb(t *testing.T) { converter := instructions.NewInstructionConverter(metrics.Get(), "armadaproject.io/", &compress.NoOpCompressor{}, true) store := lookoutdb.NewLookoutDb(db, metrics.Get(), 3, 10) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Minute) defer cancel() for _, tj := range tc.jobs { runId := uuid.NewString() @@ -156,7 +156,7 @@ func TestPruneDb(t *testing.T) { func selectStringSet(t *testing.T, db *pgxpool.Pool, query string) map[string]bool { t.Helper() - rows, err := db.Query(context.TODO(), query) + rows, err := db.Query(armadacontext.TODO(), query) assert.NoError(t, err) var ss []string for rows.Next() { diff --git a/internal/lookoutv2/repository/getjobrunerror.go b/internal/lookoutv2/repository/getjobrunerror.go index 467da22ec1a..b878c9291fb 100644 --- a/internal/lookoutv2/repository/getjobrunerror.go +++ b/internal/lookoutv2/repository/getjobrunerror.go @@ -1,18 +1,17 @@ package repository import ( - "context" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" ) type GetJobRunErrorRepository interface { - GetJobRunError(ctx context.Context, runId string) (string, error) + GetJobRunError(ctx *armadacontext.Context, runId string) (string, error) } type SqlGetJobRunErrorRepository struct { @@ -27,7 +26,7 @@ func NewSqlGetJobRunErrorRepository(db *pgxpool.Pool, decompressor compress.Deco } } -func (r *SqlGetJobRunErrorRepository) GetJobRunError(ctx 
context.Context, runId string) (string, error) { +func (r *SqlGetJobRunErrorRepository) GetJobRunError(ctx *armadacontext.Context, runId string) (string, error) { var rawBytes []byte err := r.db.QueryRow(ctx, "SELECT error FROM job_run WHERE run_id = $1 AND error IS NOT NULL", runId).Scan(&rawBytes) if err != nil { diff --git a/internal/lookoutv2/repository/getjobrunerror_test.go b/internal/lookoutv2/repository/getjobrunerror_test.go index 274de5e6d40..4bf2854929d 100644 --- a/internal/lookoutv2/repository/getjobrunerror_test.go +++ b/internal/lookoutv2/repository/getjobrunerror_test.go @@ -1,12 +1,12 @@ package repository import ( - "context" "testing" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/lookoutingesterv2/instructions" @@ -34,7 +34,7 @@ func TestGetJobRunError(t *testing.T) { ApiJob() repo := NewSqlGetJobRunErrorRepository(db, &compress.NoOpDecompressor{}) - result, err := repo.GetJobRunError(context.TODO(), runId) + result, err := repo.GetJobRunError(armadacontext.TODO(), runId) assert.NoError(t, err) assert.Equal(t, expected, result) } @@ -46,7 +46,7 @@ func TestGetJobRunError(t *testing.T) { func TestGetJobRunErrorNotFound(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobRunErrorRepository(db, &compress.NoOpDecompressor{}) - _, err := repo.GetJobRunError(context.TODO(), runId) + _, err := repo.GetJobRunError(armadacontext.TODO(), runId) assert.Error(t, err) return nil }) diff --git a/internal/lookoutv2/repository/getjobs.go b/internal/lookoutv2/repository/getjobs.go index eac6cc0aaf5..cce2550d2b1 100644 --- a/internal/lookoutv2/repository/getjobs.go +++ b/internal/lookoutv2/repository/getjobs.go @@ -1,7 +1,6 @@ package repository import ( - "context" "database/sql" "fmt" "sort" @@ -12,13 +11,14 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/lookoutv2/model" ) type GetJobsRepository interface { - GetJobs(ctx context.Context, filters []*model.Filter, order *model.Order, skip int, take int) (*GetJobsResult, error) + GetJobs(ctx *armadacontext.Context, filters []*model.Filter, order *model.Order, skip int, take int) (*GetJobsResult, error) } type SqlGetJobsRepository struct { @@ -77,7 +77,7 @@ func NewSqlGetJobsRepository(db *pgxpool.Pool) *SqlGetJobsRepository { } } -func (r *SqlGetJobsRepository) GetJobs(ctx context.Context, filters []*model.Filter, activeJobSets bool, order *model.Order, skip int, take int) (*GetJobsResult, error) { +func (r *SqlGetJobsRepository) GetJobs(ctx *armadacontext.Context, filters []*model.Filter, activeJobSets bool, order *model.Order, skip int, take int) (*GetJobsResult, error) { var jobRows []*jobRow var runRows []*runRow var annotationRows []*annotationRow @@ -243,7 +243,7 @@ func getJobRunTime(run *model.Run) (time.Time, error) { return time.Time{}, errors.Errorf("error when getting run time for run with id %s", run.RunId) } -func makeJobRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*jobRow, error) { +func makeJobRows(ctx *armadacontext.Context, tx pgx.Tx, tmpTableName string) 
([]*jobRow, error) { query := fmt.Sprintf(` SELECT j.job_id, @@ -302,7 +302,7 @@ func makeJobRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*jobRow return rows, nil } -func makeRunRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*runRow, error) { +func makeRunRows(ctx *armadacontext.Context, tx pgx.Tx, tmpTableName string) ([]*runRow, error) { query := fmt.Sprintf(` SELECT jr.job_id, @@ -347,7 +347,7 @@ func makeRunRows(ctx context.Context, tx pgx.Tx, tmpTableName string) ([]*runRow return rows, nil } -func makeAnnotationRows(ctx context.Context, tx pgx.Tx, tempTableName string) ([]*annotationRow, error) { +func makeAnnotationRows(ctx *armadacontext.Context, tx pgx.Tx, tempTableName string) ([]*annotationRow, error) { query := fmt.Sprintf(` SELECT ual.job_id, diff --git a/internal/lookoutv2/repository/getjobs_test.go b/internal/lookoutv2/repository/getjobs_test.go index 3c28805c198..d5b45a5cae7 100644 --- a/internal/lookoutv2/repository/getjobs_test.go +++ b/internal/lookoutv2/repository/getjobs_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "testing" "time" @@ -11,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/util" @@ -77,7 +77,7 @@ func TestGetJobsSingle(t *testing.T) { Job() repo := NewSqlGetJobsRepository(db) - result, err := repo.GetJobs(context.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) + result, err := repo.GetJobs(armadacontext.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) assert.NoError(t, err) assert.Len(t, result.Jobs, 1) assert.Equal(t, 1, result.Count) @@ -105,7 +105,7 @@ func TestGetJobsMultipleRuns(t *testing.T) { // Runs should be sorted from oldest -> newest repo := NewSqlGetJobsRepository(db) - result, err := repo.GetJobs(context.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) + result, err := repo.GetJobs(armadacontext.TODO(), []*model.Filter{}, false, &model.Order{}, 0, 1) assert.NoError(t, err) assert.Len(t, result.Jobs, 1) assert.Equal(t, 1, result.Count) @@ -119,7 +119,7 @@ func TestOrderByUnsupportedField(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -140,7 +140,7 @@ func TestOrderByUnsupportedDirection(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -192,7 +192,7 @@ func TestGetJobsOrderByJobId(t *testing.T) { t.Run("ascending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -212,7 +212,7 @@ func TestGetJobsOrderByJobId(t *testing.T) { t.Run("descending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -259,7 +259,7 @@ func TestGetJobsOrderBySubmissionTime(t *testing.T) { t.Run("ascending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -279,7 +279,7 @@ func 
TestGetJobsOrderBySubmissionTime(t *testing.T) { t.Run("descending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -330,7 +330,7 @@ func TestGetJobsOrderByLastTransitionTime(t *testing.T) { t.Run("ascending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -350,7 +350,7 @@ func TestGetJobsOrderByLastTransitionTime(t *testing.T) { t.Run("descending order", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -377,7 +377,7 @@ func TestFilterByUnsupportedField(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "someField", Match: model.MatchExact, @@ -400,7 +400,7 @@ func TestFilterByUnsupportedMatch(t *testing.T) { repo := NewSqlGetJobsRepository(db) _, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobId", Match: model.MatchLessThan, @@ -443,7 +443,7 @@ func TestGetJobsById(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobId", Match: model.MatchExact, @@ -499,7 +499,7 @@ func TestGetJobsByQueue(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "queue", Match: model.MatchExact, @@ -518,7 +518,7 @@ func TestGetJobsByQueue(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "queue", Match: model.MatchStartsWith, @@ -542,7 +542,7 @@ func TestGetJobsByQueue(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "queue", Match: model.MatchContains, @@ -604,7 +604,7 @@ func TestGetJobsByJobSet(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobSet", Match: model.MatchExact, @@ -623,7 +623,7 @@ func TestGetJobsByJobSet(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobSet", Match: model.MatchStartsWith, @@ -647,7 +647,7 @@ func TestGetJobsByJobSet(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "jobSet", Match: model.MatchContains, @@ -709,7 +709,7 @@ func TestGetJobsByOwner(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "owner", Match: model.MatchExact, @@ -728,7 +728,7 @@ func TestGetJobsByOwner(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "owner", Match: model.MatchStartsWith, @@ -752,7 +752,7 @@ func TestGetJobsByOwner(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "owner", Match: model.MatchContains, @@ -817,7 +817,7 @@ func 
TestGetJobsByState(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "state", Match: model.MatchExact, @@ -836,7 +836,7 @@ func TestGetJobsByState(t *testing.T) { t.Run("anyOf", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "state", Match: model.MatchAnyOf, @@ -923,7 +923,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "annotation-key-1", Match: model.MatchExact, @@ -943,7 +943,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("exact, multiple annotations", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -971,7 +971,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("startsWith, multiple annotations", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -1000,7 +1000,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("contains, multiple annotations", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -1029,7 +1029,7 @@ func TestGetJobsByAnnotation(t *testing.T) { t.Run("exists", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "annotation-key-1", @@ -1093,7 +1093,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchExact, @@ -1112,7 +1112,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchGreaterThan, @@ -1135,7 +1135,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchLessThan, @@ -1158,7 +1158,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchGreaterThanOrEqualTo, @@ -1182,7 +1182,7 @@ func TestGetJobsByCpu(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "cpu", Match: model.MatchLessThanOrEqualTo, @@ -1246,7 +1246,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchExact, @@ -1265,7 +1265,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchGreaterThan, @@ -1288,7 +1288,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchLessThan, @@ -1311,7 +1311,7 @@ func 
TestGetJobsByMemory(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchGreaterThanOrEqualTo, @@ -1335,7 +1335,7 @@ func TestGetJobsByMemory(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "memory", Match: model.MatchLessThanOrEqualTo, @@ -1399,7 +1399,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchExact, @@ -1418,7 +1418,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchGreaterThan, @@ -1441,7 +1441,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchLessThan, @@ -1464,7 +1464,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchGreaterThanOrEqualTo, @@ -1488,7 +1488,7 @@ func TestGetJobsByEphemeralStorage(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "ephemeralStorage", Match: model.MatchLessThanOrEqualTo, @@ -1552,7 +1552,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchExact, @@ -1571,7 +1571,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchGreaterThan, @@ -1594,7 +1594,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchLessThan, @@ -1617,7 +1617,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchGreaterThanOrEqualTo, @@ -1641,7 +1641,7 @@ func TestGetJobsByGpu(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "gpu", Match: model.MatchLessThanOrEqualTo, @@ -1705,7 +1705,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchExact, @@ -1724,7 +1724,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("greaterThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchGreaterThan, @@ -1747,7 +1747,7 @@ func TestGetJobsByPriority(t *testing.T) { 
t.Run("lessThan", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchLessThan, @@ -1770,7 +1770,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("greaterThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchGreaterThanOrEqualTo, @@ -1794,7 +1794,7 @@ func TestGetJobsByPriority(t *testing.T) { t.Run("lessThanOrEqualTo", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priority", Match: model.MatchLessThanOrEqualTo, @@ -1865,7 +1865,7 @@ func TestGetJobsByPriorityClass(t *testing.T) { t.Run("exact", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priorityClass", Match: model.MatchExact, @@ -1884,7 +1884,7 @@ func TestGetJobsByPriorityClass(t *testing.T) { t.Run("startsWith", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priorityClass", Match: model.MatchStartsWith, @@ -1908,7 +1908,7 @@ func TestGetJobsByPriorityClass(t *testing.T) { t.Run("contains", func(t *testing.T) { result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{{ Field: "priorityClass", Match: model.MatchContains, @@ -1957,7 +1957,7 @@ func TestGetJobsSkip(t *testing.T) { skip := 3 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1977,7 +1977,7 @@ func TestGetJobsSkip(t *testing.T) { skip := 7 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1997,7 +1997,7 @@ func TestGetJobsSkip(t *testing.T) { skip := 13 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -2057,7 +2057,7 @@ func TestGetJobsComplex(t *testing.T) { skip := 8 take := 5 result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -2121,7 +2121,7 @@ func TestGetJobsActiveJobSet(t *testing.T) { repo := NewSqlGetJobsRepository(db) result, err := repo.GetJobs( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, true, &model.Order{ diff --git a/internal/lookoutv2/repository/getjobspec.go b/internal/lookoutv2/repository/getjobspec.go index 60c6ac41cd1..55799249f35 100644 --- a/internal/lookoutv2/repository/getjobspec.go +++ b/internal/lookoutv2/repository/getjobspec.go @@ -1,20 +1,19 @@ package repository import ( - "context" - "github.com/gogo/protobuf/proto" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/pkg/api" ) type GetJobSpecRepository interface { - GetJobSpec(ctx context.Context, jobId string) (*api.Job, error) + GetJobSpec(ctx *armadacontext.Context, jobId string) (*api.Job, error) } type SqlGetJobSpecRepository struct { @@ -29,7 +28,7 @@ func NewSqlGetJobSpecRepository(db *pgxpool.Pool, decompressor compress.Decompre } } -func (r *SqlGetJobSpecRepository) GetJobSpec(ctx context.Context, jobId string) (*api.Job, error) { +func (r *SqlGetJobSpecRepository) GetJobSpec(ctx *armadacontext.Context, jobId string) 
(*api.Job, error) { var rawBytes []byte err := r.db.QueryRow(ctx, "SELECT job_spec FROM job WHERE job_id = $1", jobId).Scan(&rawBytes) if err != nil { diff --git a/internal/lookoutv2/repository/getjobspec_test.go b/internal/lookoutv2/repository/getjobspec_test.go index d7e00d83671..b13a897e8c4 100644 --- a/internal/lookoutv2/repository/getjobspec_test.go +++ b/internal/lookoutv2/repository/getjobspec_test.go @@ -1,12 +1,12 @@ package repository import ( - "context" "testing" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/lookoutingesterv2/instructions" @@ -42,7 +42,7 @@ func TestGetJobSpec(t *testing.T) { ApiJob() repo := NewSqlGetJobSpecRepository(db, &compress.NoOpDecompressor{}) - result, err := repo.GetJobSpec(context.TODO(), jobId) + result, err := repo.GetJobSpec(armadacontext.TODO(), jobId) assert.NoError(t, err) assertApiJobsEquivalent(t, job, result) return nil @@ -53,7 +53,7 @@ func TestGetJobSpec(t *testing.T) { func TestGetJobSpecError(t *testing.T) { err := lookout.WithLookoutDb(func(db *pgxpool.Pool) error { repo := NewSqlGetJobSpecRepository(db, &compress.NoOpDecompressor{}) - _, err := repo.GetJobSpec(context.TODO(), jobId) + _, err := repo.GetJobSpec(armadacontext.TODO(), jobId) assert.Error(t, err) return nil }) diff --git a/internal/lookoutv2/repository/groupjobs.go b/internal/lookoutv2/repository/groupjobs.go index dd80976dcd6..20dcb5adb0a 100644 --- a/internal/lookoutv2/repository/groupjobs.go +++ b/internal/lookoutv2/repository/groupjobs.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "strings" @@ -9,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/lookoutv2/model" @@ -22,7 +22,7 @@ type GroupByResult struct { type GroupJobsRepository interface { GroupBy( - ctx context.Context, + ctx *armadacontext.Context, filters []*model.Filter, order *model.Order, groupedField string, @@ -47,7 +47,7 @@ func NewSqlGroupJobsRepository(db *pgxpool.Pool) *SqlGroupJobsRepository { } func (r *SqlGroupJobsRepository) GroupBy( - ctx context.Context, + ctx *armadacontext.Context, filters []*model.Filter, activeJobSets bool, order *model.Order, diff --git a/internal/lookoutv2/repository/groupjobs_test.go b/internal/lookoutv2/repository/groupjobs_test.go index 1f255029f8c..b2bd04d5d03 100644 --- a/internal/lookoutv2/repository/groupjobs_test.go +++ b/internal/lookoutv2/repository/groupjobs_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/pointer" @@ -39,7 +39,7 @@ func TestGroupByQueue(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -98,7 +98,7 @@ func TestGroupByJobSet(t *testing.T) { repo 
:= NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -165,7 +165,7 @@ func TestGroupByState(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -330,7 +330,7 @@ func TestGroupByWithFilters(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -452,7 +452,7 @@ func TestGroupJobsWithMaxSubmittedTime(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -552,7 +552,7 @@ func TestGroupJobsWithAvgLastTransitionTime(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -652,7 +652,7 @@ func TestGroupJobsWithAllStateCounts(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -774,7 +774,7 @@ func TestGroupJobsWithFilteredStateCounts(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: stateField, @@ -898,7 +898,7 @@ func TestGroupJobsComplex(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -997,7 +997,7 @@ func TestGroupByAnnotation(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1112,7 +1112,7 @@ func TestGroupByAnnotationWithFiltersAndAggregates(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{ { Field: "queue", @@ -1212,7 +1212,7 @@ func TestGroupJobsSkip(t *testing.T) { skip := 3 take := 5 result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1242,7 +1242,7 @@ func TestGroupJobsSkip(t *testing.T) { skip := 7 take := 5 result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1272,7 +1272,7 @@ func TestGroupJobsSkip(t *testing.T) { skip := 13 take := 5 result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1306,7 +1306,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("valid field", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1325,7 +1325,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("invalid field", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1344,7 +1344,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("valid annotation", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, false, &model.Order{ @@ -1364,7 +1364,7 @@ func TestGroupJobsValidation(t *testing.T) { t.Run("valid annotation with same name as column", func(t *testing.T) { _, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), 
[]*model.Filter{}, false, &model.Order{ @@ -1427,7 +1427,7 @@ func TestGroupByActiveJobSets(t *testing.T) { repo := NewSqlGroupJobsRepository(db) result, err := repo.GroupBy( - context.TODO(), + armadacontext.TODO(), []*model.Filter{}, true, &model.Order{ diff --git a/internal/lookoutv2/repository/util.go b/internal/lookoutv2/repository/util.go index d250f3844dc..62143df1f37 100644 --- a/internal/lookoutv2/repository/util.go +++ b/internal/lookoutv2/repository/util.go @@ -1,7 +1,6 @@ package repository import ( - "context" "fmt" "strings" "time" @@ -13,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database/lookout" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/ingest" @@ -586,8 +586,8 @@ func (js *JobSimulator) Build() *JobSimulator { EventSequences: []*armadaevents.EventSequence{eventSequence}, MessageIds: []pulsar.MessageID{pulsarutils.NewMessageId(1)}, } - instructionSet := js.converter.Convert(context.TODO(), eventSequenceWithIds) - err := js.store.Store(context.TODO(), instructionSet) + instructionSet := js.converter.Convert(armadacontext.TODO(), eventSequenceWithIds) + err := js.store.Store(armadacontext.TODO(), instructionSet) if err != nil { log.WithError(err).Error("Simulator failed to store job in database") } diff --git a/internal/pulsartest/watch.go b/internal/pulsartest/watch.go index cbe6e5834fa..210916cf7e1 100644 --- a/internal/pulsartest/watch.go +++ b/internal/pulsartest/watch.go @@ -1,13 +1,13 @@ package pulsartest import ( - "context" "fmt" "log" "os" "github.com/sanity-io/litter" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/pulsarutils" ) @@ -17,12 +17,12 @@ func (a *App) Watch() error { defer a.Reader.Close() for a.Reader.HasNext() { - msg, err := a.Reader.Next(context.Background()) + msg, err := a.Reader.Next(armadacontext.Background()) if err != nil { log.Fatal(err) } - ctx := context.Background() + ctx := armadacontext.Background() msgId := pulsarutils.New(msg.ID().LedgerID(), msg.ID().EntryID(), msg.ID().PartitionIdx(), msg.ID().BatchIdx()) diff --git a/internal/scheduler/api.go b/internal/scheduler/api.go index 2e6f2779504..a31eba85f5e 100644 --- a/internal/scheduler/api.go +++ b/internal/scheduler/api.go @@ -8,10 +8,10 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/types" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -81,9 +81,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns return errors.WithStack(err) } - ctx := stream.Context() - log := ctxlogrus.Extract(ctx) - log = log.WithField("executor", req.ExecutorId) + ctx := armadacontext.WithLogField(armadacontext.FromGrpcCtx(stream.Context()), "executor", req.ExecutorId) executor := srv.executorFromLeaseRequest(ctx, req) if err := srv.executorRepository.StoreExecutor(ctx, executor); err != nil { @@ -105,7 +103,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream 
executorapi.ExecutorApi_LeaseJobRuns if err != nil { return err } - log.Infof( + ctx.Log.Infof( "executor currently has %d job runs; sending %d cancellations and %d new runs", len(requestRuns), len(runsToCancel), len(newRuns), ) @@ -216,19 +214,19 @@ func setPriorityClassName(podSpec *armadaevents.PodSpecWithAvoidList, priorityCl } // ReportEvents publishes all events to Pulsar. The events are compacted for more efficient publishing. -func (srv *ExecutorApi) ReportEvents(ctx context.Context, list *executorapi.EventList) (*types.Empty, error) { +func (srv *ExecutorApi) ReportEvents(grpcCtx context.Context, list *executorapi.EventList) (*types.Empty, error) { + ctx := armadacontext.FromGrpcCtx(grpcCtx) err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSizeBytes, schedulers.Pulsar) return &types.Empty{}, err } // executorFromLeaseRequest extracts a schedulerobjects.Executor from the request. -func (srv *ExecutorApi) executorFromLeaseRequest(ctx context.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor { - log := ctxlogrus.Extract(ctx) +func (srv *ExecutorApi) executorFromLeaseRequest(ctx *armadacontext.Context, req *executorapi.LeaseRequest) *schedulerobjects.Executor { nodes := make([]*schedulerobjects.Node, 0, len(req.Nodes)) now := srv.clock.Now().UTC() for _, nodeInfo := range req.Nodes { if node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, now); err != nil { - logging.WithStacktrace(log, err).Warnf( + logging.WithStacktrace(ctx.Log, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetExecutorId(), ) } else { diff --git a/internal/scheduler/api_test.go b/internal/scheduler/api_test.go index 77a3c52f7a9..f388a5129d4 100644 --- a/internal/scheduler/api_test.go +++ b/internal/scheduler/api_test.go @@ -1,6 +1,7 @@ package scheduler import ( + "context" "testing" "time" @@ -10,10 +11,10 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/net/context" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/mocks" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -165,7 +166,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) ctrl := gomock.NewController(t) mockPulsarProducer := mocks.NewMockProducer(ctrl) mockJobRepository := schedulermocks.NewMockJobRepository(ctrl) @@ -179,11 +180,11 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { // set up mocks mockStream.EXPECT().Context().Return(ctx).AnyTimes() mockStream.EXPECT().Recv().Return(tc.request, nil).Times(1) - mockExecutorRepository.EXPECT().StoreExecutor(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, executor *schedulerobjects.Executor) error { + mockExecutorRepository.EXPECT().StoreExecutor(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { assert.Equal(t, tc.expectedExecutor, executor) return nil }).Times(1) - mockLegacyExecutorRepository.EXPECT().StoreExecutor(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, executor 
*schedulerobjects.Executor) error { + mockLegacyExecutorRepository.EXPECT().StoreExecutor(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { assert.Equal(t, tc.expectedExecutor, executor) return nil }).Times(1) @@ -304,7 +305,7 @@ func TestExecutorApi_Publish(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) ctrl := gomock.NewController(t) mockPulsarProducer := mocks.NewMockProducer(ctrl) mockJobRepository := schedulermocks.NewMockJobRepository(ctrl) diff --git a/internal/scheduler/database/db.go b/internal/scheduler/database/db.go index 5af3de156f4..8f9fc5e6de2 100644 --- a/internal/scheduler/database/db.go +++ b/internal/scheduler/database/db.go @@ -7,8 +7,8 @@ package database import ( "context" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type DBTX interface { diff --git a/internal/scheduler/database/db_pruner.go b/internal/scheduler/database/db_pruner.go index 728c3c9b71b..9ea8075a40d 100644 --- a/internal/scheduler/database/db_pruner.go +++ b/internal/scheduler/database/db_pruner.go @@ -28,7 +28,7 @@ func PruneDb(ctx ctx.Context, db *pgx.Conn, batchLimit int, keepAfterCompletion // Insert the ids of all jobs we want to delete into a tmp table _, err = db.Exec(ctx, `CREATE TEMP TABLE rows_to_delete AS ( - SELECT job_id FROM jobs + SELECT job_id FROM jobs WHERE last_modified < $1 AND (succeeded = TRUE OR failed = TRUE OR cancelled = TRUE))`, cutOffTime) if err != nil { diff --git a/internal/scheduler/database/db_pruner_test.go b/internal/scheduler/database/db_pruner_test.go index bd1165ed2d3..1a30c200463 100644 --- a/internal/scheduler/database/db_pruner_test.go +++ b/internal/scheduler/database/db_pruner_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" commonutil "github.com/armadaproject/armada/internal/common/util" ) @@ -108,7 +108,7 @@ func TestPruneDb_RemoveJobs(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := WithTestDb(func(_ *Queries, db *pgxpool.Pool) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() testClock := clock.NewFakeClock(baseTime) @@ -186,7 +186,7 @@ func TestPruneDb_RemoveMarkers(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := WithTestDb(func(_ *Queries, db *pgxpool.Pool) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() testClock := clock.NewFakeClock(baseTime) @@ -220,7 +220,7 @@ func TestPruneDb_RemoveMarkers(t *testing.T) { // Removes the triggers that auto-set serial and last_update_time as // we need to manipulate these as part of the test -func removeTriggers(ctx context.Context, db *pgxpool.Pool) error { +func removeTriggers(ctx *armadacontext.Context, db *pgxpool.Pool) error { triggers := map[string]string{ "jobs": "next_serial_on_insert_jobs", "runs": 
"next_serial_on_insert_runs", diff --git a/internal/scheduler/database/executor_repository.go b/internal/scheduler/database/executor_repository.go index c2da2442e54..ec50db20126 100644 --- a/internal/scheduler/database/executor_repository.go +++ b/internal/scheduler/database/executor_repository.go @@ -1,13 +1,13 @@ package database import ( - "context" "time" "github.com/gogo/protobuf/proto" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -15,11 +15,11 @@ import ( // ExecutorRepository is an interface to be implemented by structs which provide executor information type ExecutorRepository interface { // GetExecutors returns all known executors, regardless of their last heartbeat time - GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) + GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) // GetLastUpdateTimes returns a map of executor name -> last heartbeat time - GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) + GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) // StoreExecutor persists the latest executor state - StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error + StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error } // PostgresExecutorRepository is an implementation of ExecutorRepository that stores its state in postgres @@ -40,7 +40,7 @@ func NewPostgresExecutorRepository(db *pgxpool.Pool) *PostgresExecutorRepository } // GetExecutors returns all known executors, regardless of their last heartbeat time -func (r *PostgresExecutorRepository) GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) { +func (r *PostgresExecutorRepository) GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) { queries := New(r.db) requests, err := queries.SelectAllExecutors(ctx) if err != nil { @@ -59,7 +59,7 @@ func (r *PostgresExecutorRepository) GetExecutors(ctx context.Context) ([]*sched } // GetLastUpdateTimes returns a map of executor name -> last heartbeat time -func (r *PostgresExecutorRepository) GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) { +func (r *PostgresExecutorRepository) GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) { queries := New(r.db) rows, err := queries.SelectExecutorUpdateTimes(ctx) if err != nil { @@ -74,7 +74,7 @@ func (r *PostgresExecutorRepository) GetLastUpdateTimes(ctx context.Context) (ma } // StoreExecutor persists the latest executor state -func (r *PostgresExecutorRepository) StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error { +func (r *PostgresExecutorRepository) StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { queries := New(r.db) bytes, err := proto.Marshal(executor) if err != nil { diff --git a/internal/scheduler/database/executor_repository_test.go b/internal/scheduler/database/executor_repository_test.go index 2d7bd206512..76a0e14c9f9 100644 --- a/internal/scheduler/database/executor_repository_test.go +++ b/internal/scheduler/database/executor_repository_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + 
"github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -53,7 +53,7 @@ func TestExecutorRepository_LoadAndSave(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withExecutorRepository(func(repo *PostgresExecutorRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() for _, executor := range tc.executors { err := repo.StoreExecutor(ctx, executor) @@ -106,7 +106,7 @@ func TestExecutorRepository_GetLastUpdateTimes(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withExecutorRepository(func(repo *PostgresExecutorRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() for _, executor := range tc.executors { err := repo.StoreExecutor(ctx, executor) diff --git a/internal/scheduler/database/job_repository.go b/internal/scheduler/database/job_repository.go index c4eaf606099..ebc08d03230 100644 --- a/internal/scheduler/database/job_repository.go +++ b/internal/scheduler/database/job_repository.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "github.com/google/uuid" @@ -9,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" protoutil "github.com/armadaproject/armada/internal/common/proto" @@ -35,24 +35,24 @@ type JobRunLease struct { type JobRepository interface { // FetchJobUpdates returns all jobs and job dbRuns that have been updated after jobSerial and jobRunSerial respectively // These updates are guaranteed to be consistent with each other - FetchJobUpdates(ctx context.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) + FetchJobUpdates(ctx *armadacontext.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) // FetchJobRunErrors returns all armadaevents.JobRunErrors for the provided job run ids. The returned map is // keyed by job run id. Any dbRuns which don't have errors wil be absent from the map. - FetchJobRunErrors(ctx context.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) + FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) // CountReceivedPartitions returns a count of the number of partition messages present in the database corresponding // to the provided groupId. This is used by the scheduler to determine if the database represents the state of // pulsar after a given point in time. - CountReceivedPartitions(ctx context.Context, groupId uuid.UUID) (uint32, error) + CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) // FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active // Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled - FindInactiveRuns(ctx context.Context, runIds []uuid.UUID) ([]uuid.UUID, error) + FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) // FetchJobRunLeases fetches new job runs for a given executor. 
A maximum of maxResults rows will be returned, while run // in excludedRunIds will be excluded - FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) + FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) } // PostgresJobRepository is an implementation of JobRepository that stores its state in postgres @@ -72,7 +72,7 @@ func NewPostgresJobRepository(db *pgxpool.Pool, batchSize int32) *PostgresJobRep // FetchJobRunErrors returns all armadaevents.JobRunErrors for the provided job run ids. The returned map is // keyed by job run id. Any dbRuns which don't have errors wil be absent from the map. -func (r *PostgresJobRepository) FetchJobRunErrors(ctx context.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { +func (r *PostgresJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { if len(runIds) == 0 { return map[uuid.UUID]*armadaevents.Error{}, nil } @@ -125,7 +125,7 @@ func (r *PostgresJobRepository) FetchJobRunErrors(ctx context.Context, runIds [] // FetchJobUpdates returns all jobs and job dbRuns that have been updated after jobSerial and jobRunSerial respectively // These updates are guaranteed to be consistent with each other -func (r *PostgresJobRepository) FetchJobUpdates(ctx context.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) { +func (r *PostgresJobRepository) FetchJobUpdates(ctx *armadacontext.Context, jobSerial int64, jobRunSerial int64) ([]Job, []Run, error) { var updatedJobs []Job = nil var updatedRuns []Run = nil @@ -180,7 +180,7 @@ func (r *PostgresJobRepository) FetchJobUpdates(ctx context.Context, jobSerial i // FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active // Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled -func (r *PostgresJobRepository) FindInactiveRuns(ctx context.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { +func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { var inactiveRuns []uuid.UUID err := pgx.BeginTxFunc(ctx, r.db, pgx.TxOptions{ IsoLevel: pgx.ReadCommitted, @@ -221,7 +221,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx context.Context, runIds []u // FetchJobRunLeases fetches new job runs for a given executor. A maximum of maxResults rows will be returned, while run // in excludedRunIds will be excluded -func (r *PostgresJobRepository) FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) { +func (r *PostgresJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) { if maxResults == 0 { return []*JobRunLease{}, nil } @@ -272,7 +272,7 @@ func (r *PostgresJobRepository) FetchJobRunLeases(ctx context.Context, executor // CountReceivedPartitions returns a count of the number of partition messages present in the database corresponding // to the provided groupId. This is used by the scheduler to determine if the database represents the state of // pulsar after a given point in time. 
-func (r *PostgresJobRepository) CountReceivedPartitions(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (r *PostgresJobRepository) CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { queries := New(r.db) count, err := queries.CountGroup(ctx, groupId) if err != nil { @@ -300,7 +300,7 @@ func fetch[T hasSerial](from int64, batchSize int32, fetchBatch func(int64) ([]T } // Insert all run ids into a tmp table. The name of the table is returned -func insertRunIdsToTmpTable(ctx context.Context, tx pgx.Tx, runIds []uuid.UUID) (string, error) { +func insertRunIdsToTmpTable(ctx *armadacontext.Context, tx pgx.Tx, runIds []uuid.UUID) (string, error) { tmpTable := database.UniqueTableName("job_runs") _, err := tx.Exec(ctx, fmt.Sprintf("CREATE TEMPORARY TABLE %s (run_id uuid) ON COMMIT DROP", tmpTable)) diff --git a/internal/scheduler/database/job_repository_test.go b/internal/scheduler/database/job_repository_test.go index b236618185b..771d887b17d 100644 --- a/internal/scheduler/database/job_repository_test.go +++ b/internal/scheduler/database/job_repository_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "testing" "time" @@ -13,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" protoutil "github.com/armadaproject/armada/internal/common/proto" @@ -84,7 +84,7 @@ func TestFetchJobUpdates(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "jobs", tc.dbJobs) @@ -187,7 +187,7 @@ func TestFetchJobRunErrors(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "job_run_errors", tc.errorsInDb) require.NoError(t, err) @@ -222,7 +222,7 @@ func TestCountReceivedPartitions(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) markers := make([]Marker, tc.numPartitions) groupId := uuid.New() @@ -357,7 +357,7 @@ func TestFindInactiveRuns(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 500*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "runs", tc.dbRuns) @@ -487,7 +487,7 @@ func TestFetchJobRunLeases(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { err := withJobRepository(func(repo *PostgresJobRepository) error { - ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) // Set up db err := database.UpsertWithTransaction(ctx, repo.db, "jobs", tc.dbJobs) @@ -553,7 +553,7 @@ func withJobRepository(action func(repository *PostgresJobRepository) error) err }) } -func insertMarkers(ctx context.Context, markers []Marker, db *pgxpool.Pool) error { +func insertMarkers(ctx *armadacontext.Context, markers []Marker, db *pgxpool.Pool) error { for _, marker := range markers { _, err := db.Exec(ctx, "INSERT INTO markers VALUES ($1, $2)", marker.GroupID, marker.PartitionID) if err != nil { diff --git a/internal/scheduler/database/redis_executor_repository.go b/internal/scheduler/database/redis_executor_repository.go index 989710a69da..ef775ff7f75 100644 --- a/internal/scheduler/database/redis_executor_repository.go +++ b/internal/scheduler/database/redis_executor_repository.go @@ -1,7 +1,6 @@ package database import ( - "context" "fmt" "time" @@ -9,6 +8,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pkg/errors" + "github.com/armadaproject/armada/internal/common/armadacontext" protoutil "github.com/armadaproject/armada/internal/common/proto" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -29,7 +29,7 @@ func NewRedisExecutorRepository(db redis.UniversalClient, schedulerName string) } } -func (r *RedisExecutorRepository) GetExecutors(_ context.Context) ([]*schedulerobjects.Executor, error) { +func (r *RedisExecutorRepository) GetExecutors(_ *armadacontext.Context) ([]*schedulerobjects.Executor, error) { result, err := r.db.HGetAll(r.executorsKey).Result() if err != nil { return nil, errors.Wrap(err, "Error retrieving executors from redis") @@ -47,12 +47,12 @@ func (r *RedisExecutorRepository) GetExecutors(_ context.Context) ([]*schedulero return executors, nil } -func (r *RedisExecutorRepository) GetLastUpdateTimes(_ context.Context) (map[string]time.Time, error) { +func (r *RedisExecutorRepository) GetLastUpdateTimes(_ *armadacontext.Context) (map[string]time.Time, error) { // We could implement this in a very inefficient way, but I don't believe it's needed so panic for now panic("GetLastUpdateTimes is not implemented") } -func (r *RedisExecutorRepository) StoreExecutor(_ context.Context, executor *schedulerobjects.Executor) error { +func (r *RedisExecutorRepository) StoreExecutor(_ *armadacontext.Context, executor *schedulerobjects.Executor) error { data, err := proto.Marshal(executor) if err != nil { return errors.Wrap(err, "Error marshalling executor proto") diff --git a/internal/scheduler/database/redis_executor_repository_test.go b/internal/scheduler/database/redis_executor_repository_test.go index 6fb48d66c49..bf5b0ea9629 100644 --- a/internal/scheduler/database/redis_executor_repository_test.go +++ b/internal/scheduler/database/redis_executor_repository_test.go @@ -1,7 +1,6 @@ package database import ( - "context" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -53,7 +53,7 @@ func TestRedisExecutorRepository_LoadAndSave(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { withRedisExecutorRepository(func(repo *RedisExecutorRepository) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := 
armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() for _, executor := range tc.executors { err := repo.StoreExecutor(ctx, executor) diff --git a/internal/scheduler/database/util.go b/internal/scheduler/database/util.go index d6539a2a743..618c32c8efb 100644 --- a/internal/scheduler/database/util.go +++ b/internal/scheduler/database/util.go @@ -1,7 +1,6 @@ package database import ( - "context" "embed" _ "embed" "time" @@ -9,13 +8,14 @@ import ( "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" ) //go:embed migrations/*.sql var fs embed.FS -func Migrate(ctx context.Context, db database.Querier) error { +func Migrate(ctx *armadacontext.Context, db database.Querier) error { start := time.Now() migrations, err := database.ReadMigrations(fs, "migrations") if err != nil { diff --git a/internal/scheduler/gang_scheduler.go b/internal/scheduler/gang_scheduler.go index ffca7be9f8e..fb9a3add118 100644 --- a/internal/scheduler/gang_scheduler.go +++ b/internal/scheduler/gang_scheduler.go @@ -1,11 +1,11 @@ package scheduler import ( - "context" "fmt" "github.com/hashicorp/go-memdb" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -38,7 +38,7 @@ func (sch *GangScheduler) SkipUnsuccessfulSchedulingKeyCheck() { sch.skipUnsuccessfulSchedulingKeyCheck = true } -func (sch *GangScheduler) Schedule(ctx context.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) Schedule(ctx *armadacontext.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { // Exit immediately if this is a new gang and we've exceeded any round limits. // // Because this check occurs before adding the gctx to the sctx, @@ -109,7 +109,7 @@ func (sch *GangScheduler) Schedule(ctx context.Context, gctx *schedulercontext.G return sch.trySchedule(ctx, gctx) } -func (sch *GangScheduler) trySchedule(ctx context.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) trySchedule(ctx *armadacontext.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { // If no node uniformity constraint, try scheduling across all nodes. 
if gctx.NodeUniformityLabel == "" { return sch.tryScheduleGang(ctx, gctx) @@ -176,7 +176,7 @@ func (sch *GangScheduler) trySchedule(ctx context.Context, gctx *schedulercontex return sch.tryScheduleGang(ctx, gctx) } -func (sch *GangScheduler) tryScheduleGang(ctx context.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) tryScheduleGang(ctx *armadacontext.Context, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { txn := sch.nodeDb.Txn(true) defer txn.Abort() ok, unschedulableReason, err = sch.tryScheduleGangWithTxn(ctx, txn, gctx) @@ -186,7 +186,7 @@ func (sch *GangScheduler) tryScheduleGang(ctx context.Context, gctx *schedulerco return } -func (sch *GangScheduler) tryScheduleGangWithTxn(ctx context.Context, txn *memdb.Txn, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { +func (sch *GangScheduler) tryScheduleGangWithTxn(ctx *armadacontext.Context, txn *memdb.Txn, gctx *schedulercontext.GangSchedulingContext) (ok bool, unschedulableReason string, err error) { if ok, err = sch.nodeDb.ScheduleManyWithTxn(txn, gctx.JobSchedulingContexts); err != nil { return } else if !ok { diff --git a/internal/scheduler/gang_scheduler_test.go b/internal/scheduler/gang_scheduler_test.go index e5fbafad703..cc79703d2b2 100644 --- a/internal/scheduler/gang_scheduler_test.go +++ b/internal/scheduler/gang_scheduler_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -10,6 +9,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -372,7 +372,7 @@ func TestGangScheduler(t *testing.T) { for i, gang := range tc.Gangs { jctxs := schedulercontext.JobSchedulingContextsFromJobs(testfixtures.TestPriorityClasses, gang) gctx := schedulercontext.NewGangSchedulingContext(jctxs) - ok, reason, err := sch.Schedule(context.Background(), gctx) + ok, reason, err := sch.Schedule(armadacontext.Background(), gctx) require.NoError(t, err) if ok { require.Empty(t, reason) diff --git a/internal/scheduler/jobiteration.go b/internal/scheduler/jobiteration.go index 7b232edc141..04dd63a6490 100644 --- a/internal/scheduler/jobiteration.go +++ b/internal/scheduler/jobiteration.go @@ -1,13 +1,12 @@ package scheduler import ( - "context" "sync" "golang.org/x/exp/maps" "golang.org/x/exp/slices" - "golang.org/x/sync/errgroup" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/types" "github.com/armadaproject/armada/internal/scheduler/interfaces" ) @@ -136,7 +135,7 @@ func (repo *InMemoryJobRepository) GetExistingJobsByIds(jobIds []string) ([]inte return rv, nil } -func (repo *InMemoryJobRepository) GetJobIterator(ctx context.Context, queue string) (JobIterator, error) { +func (repo *InMemoryJobRepository) GetJobIterator(ctx *armadacontext.Context, queue string) (JobIterator, error) { repo.mu.Lock() defer repo.mu.Unlock() return NewInMemoryJobIterator(slices.Clone(repo.jobsByQueue[queue])), nil @@ -145,14 +144,14 @@ func (repo *InMemoryJobRepository) GetJobIterator(ctx context.Context, queue str 
// QueuedJobsIterator is an iterator over all jobs in a queue. // It lazily loads jobs in batches from Redis asynch. type QueuedJobsIterator struct { - ctx context.Context + ctx *armadacontext.Context err error c chan interfaces.LegacySchedulerJob } -func NewQueuedJobsIterator(ctx context.Context, queue string, repo JobRepository) (*QueuedJobsIterator, error) { +func NewQueuedJobsIterator(ctx *armadacontext.Context, queue string, repo JobRepository) (*QueuedJobsIterator, error) { batchSize := 16 - g, ctx := errgroup.WithContext(ctx) + g, ctx := armadacontext.ErrGroup(ctx) it := &QueuedJobsIterator{ ctx: ctx, c: make(chan interfaces.LegacySchedulerJob, 2*batchSize), // 2x batchSize to load one batch async. @@ -190,7 +189,7 @@ func (it *QueuedJobsIterator) Next() (interfaces.LegacySchedulerJob, error) { // queuedJobsIteratorLoader loads jobs from Redis lazily. // Used with QueuedJobsIterator. -func queuedJobsIteratorLoader(ctx context.Context, jobIds []string, ch chan interfaces.LegacySchedulerJob, batchSize int, repo JobRepository) error { +func queuedJobsIteratorLoader(ctx *armadacontext.Context, jobIds []string, ch chan interfaces.LegacySchedulerJob, batchSize int, repo JobRepository) error { defer close(ch) batch := make([]string, batchSize) for i, jobId := range jobIds { diff --git a/internal/scheduler/jobiteration_test.go b/internal/scheduler/jobiteration_test.go index 42133f0ba05..a5990fa3fc4 100644 --- a/internal/scheduler/jobiteration_test.go +++ b/internal/scheduler/jobiteration_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/interfaces" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" @@ -87,7 +88,7 @@ func TestMultiJobsIterator_TwoQueues(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() its := make([]JobIterator, 3) for i, queue := range []string{"A", "B", "C"} { it, err := NewQueuedJobsIterator(ctx, queue, repo) @@ -121,7 +122,7 @@ func TestQueuedJobsIterator_OneQueue(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -146,7 +147,7 @@ func TestQueuedJobsIterator_ExceedsBufferSize(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -171,7 +172,7 @@ func TestQueuedJobsIterator_ManyJobs(t *testing.T) { expected = append(expected, job.Id) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -200,7 +201,7 @@ func TestCreateQueuedJobsIterator_TwoQueues(t *testing.T) { repo.Enqueue(job) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -223,7 +224,7 @@ func TestCreateQueuedJobsIterator_RespectsTimeout(t *testing.T) { repo.Enqueue(job) } - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Millisecond) time.Sleep(20 * time.Millisecond) defer cancel() it, err := 
NewQueuedJobsIterator(ctx, "A", repo) @@ -248,7 +249,7 @@ func TestCreateQueuedJobsIterator_NilOnEmpty(t *testing.T) { repo.Enqueue(job) } - ctx := context.Background() + ctx := armadacontext.Background() it, err := NewQueuedJobsIterator(ctx, "A", repo) if !assert.NoError(t, err) { return @@ -291,7 +292,7 @@ func (repo *mockJobRepository) Enqueue(job *api.Job) { repo.jobsById[job.Id] = job } -func (repo *mockJobRepository) GetJobIterator(ctx context.Context, queue string) (JobIterator, error) { +func (repo *mockJobRepository) GetJobIterator(ctx *armadacontext.Context, queue string) (JobIterator, error) { return NewQueuedJobsIterator(ctx, queue, repo) } diff --git a/internal/scheduler/leader.go b/internal/scheduler/leader.go index a0c8b8a85f6..0482184a7a8 100644 --- a/internal/scheduler/leader.go +++ b/internal/scheduler/leader.go @@ -6,12 +6,12 @@ import ( "sync/atomic" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" coordinationv1client "k8s.io/client-go/kubernetes/typed/coordination/v1" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/leaderelection/resourcelock" + "github.com/armadaproject/armada/internal/common/armadacontext" schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" ) @@ -23,7 +23,7 @@ type LeaderController interface { // Returns true if the token is a leader and false otherwise ValidateToken(tok LeaderToken) bool // Run starts the controller. This is a blocking call which will return when the provided context is cancelled - Run(ctx context.Context) error + Run(ctx *armadacontext.Context) error // GetLeaderReport returns a report about the current leader GetLeaderReport() LeaderReport } @@ -85,14 +85,14 @@ func (lc *StandaloneLeaderController) ValidateToken(tok LeaderToken) bool { return false } -func (lc *StandaloneLeaderController) Run(ctx context.Context) error { +func (lc *StandaloneLeaderController) Run(ctx *armadacontext.Context) error { return nil } // LeaseListener allows clients to listen for lease events. type LeaseListener interface { // Called when the client has started leading. - onStartedLeading(context.Context) + onStartedLeading(*armadacontext.Context) // Called when the client has stopped leading, onStoppedLeading() } @@ -138,16 +138,14 @@ func (lc *KubernetesLeaderController) ValidateToken(tok LeaderToken) bool { // Run starts the controller. // This is a blocking call that returns when the provided context is cancelled. 
-func (lc *KubernetesLeaderController) Run(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("service", "KubernetesLeaderController") +func (lc *KubernetesLeaderController) Run(ctx *armadacontext.Context) error { for { select { case <-ctx.Done(): return ctx.Err() default: lock := lc.getNewLock() - log.Infof("attempting to become leader") + ctx.Log.Infof("attempting to become leader") leaderelection.RunOrDie(ctx, leaderelection.LeaderElectionConfig{ Lock: lock, ReleaseOnCancel: true, @@ -156,14 +154,14 @@ func (lc *KubernetesLeaderController) Run(ctx context.Context) error { RetryPeriod: lc.config.RetryPeriod, Callbacks: leaderelection.LeaderCallbacks{ OnStartedLeading: func(c context.Context) { - log.Infof("I am now leader") + ctx.Log.Infof("I am now leader") lc.token.Store(NewLeaderToken()) for _, listener := range lc.listeners { listener.onStartedLeading(ctx) } }, OnStoppedLeading: func() { - log.Infof("I am no longer leader") + ctx.Log.Infof("I am no longer leader") lc.token.Store(InvalidLeaderToken()) for _, listener := range lc.listeners { listener.onStoppedLeading() @@ -176,7 +174,7 @@ func (lc *KubernetesLeaderController) Run(ctx context.Context) error { }, }, }) - log.Infof("leader election round finished") + ctx.Log.Infof("leader election round finished") } } } diff --git a/internal/scheduler/leader_client_test.go b/internal/scheduler/leader_client_test.go index e8909356402..31ba46a8913 100644 --- a/internal/scheduler/leader_client_test.go +++ b/internal/scheduler/leader_client_test.go @@ -1,11 +1,11 @@ package scheduler import ( - "context" "testing" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/configuration" "github.com/armadaproject/armada/pkg/client" ) @@ -91,7 +91,7 @@ func (f *FakeLeaderController) ValidateToken(tok LeaderToken) bool { return f.IsCurrentlyLeader } -func (f *FakeLeaderController) Run(ctx context.Context) error { +func (f *FakeLeaderController) Run(_ *armadacontext.Context) error { return nil } diff --git a/internal/scheduler/leader_metrics.go b/internal/scheduler/leader_metrics.go index cc02157504e..d5d4e62f535 100644 --- a/internal/scheduler/leader_metrics.go +++ b/internal/scheduler/leader_metrics.go @@ -1,11 +1,11 @@ package scheduler import ( - "context" "sync" "github.com/prometheus/client_golang/prometheus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/metrics" ) @@ -29,7 +29,7 @@ func NewLeaderStatusMetricsCollector(currentInstanceName string) *LeaderStatusMe } } -func (l *LeaderStatusMetricsCollector) onStartedLeading(context.Context) { +func (l *LeaderStatusMetricsCollector) onStartedLeading(*armadacontext.Context) { l.lock.Lock() defer l.lock.Unlock() diff --git a/internal/scheduler/leader_metrics_test.go b/internal/scheduler/leader_metrics_test.go index fec5d4e5d08..8132179afbd 100644 --- a/internal/scheduler/leader_metrics_test.go +++ b/internal/scheduler/leader_metrics_test.go @@ -1,11 +1,12 @@ package scheduler import ( - "context" "testing" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) const testInstanceName = "instance-1" @@ -31,7 +32,7 @@ func TestLeaderStatusMetrics_HandlesLeaderChanges(t *testing.T) { assert.Equal(t, actual[0], isNotLeaderMetric) // start leading - collector.onStartedLeading(context.Background()) + 
collector.onStartedLeading(armadacontext.Background()) actual = getCurrentMetrics(collector) assert.Len(t, actual, 1) assert.Equal(t, actual[0], isLeaderMetric) diff --git a/internal/scheduler/leader_proxying_reports_server_test.go b/internal/scheduler/leader_proxying_reports_server_test.go index 5fc1874d210..2b83a02da28 100644 --- a/internal/scheduler/leader_proxying_reports_server_test.go +++ b/internal/scheduler/leader_proxying_reports_server_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/grpc" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -48,7 +49,7 @@ func TestLeaderProxyingSchedulingReportsServer_GetJobReports(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, clientProvider, jobReportsServer, jobReportsClient := setupLeaderProxyingSchedulerReportsServerTest(t) @@ -113,7 +114,7 @@ func TestLeaderProxyingSchedulingReportsServer_GetSchedulingReport(t *testing.T) } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, clientProvider, jobReportsServer, jobReportsClient := setupLeaderProxyingSchedulerReportsServerTest(t) @@ -178,7 +179,7 @@ func TestLeaderProxyingSchedulingReportsServer_GetQueueReport(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, clientProvider, jobReportsServer, jobReportsClient := setupLeaderProxyingSchedulerReportsServerTest(t) diff --git a/internal/scheduler/leader_test.go b/internal/scheduler/leader_test.go index 1790c9518b5..17fb468b0cf 100644 --- a/internal/scheduler/leader_test.go +++ b/internal/scheduler/leader_test.go @@ -12,6 +12,7 @@ import ( v1 "k8s.io/api/coordination/v1" "k8s.io/utils/pointer" + "github.com/armadaproject/armada/internal/common/armadacontext" schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks" ) @@ -108,7 +109,7 @@ func TestK8sLeaderController_BecomingLeader(t *testing.T) { controller := NewKubernetesLeaderController(testLeaderConfig(), client) testListener := NewTestLeaseListener(controller) controller.RegisterListener(testListener) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) go func() { err := controller.Run(ctx) assert.ErrorIs(t, err, context.Canceled) @@ -184,7 +185,7 @@ func (t *TestLeaseListener) GetMessages() []LeaderToken { return append([]LeaderToken(nil), t.tokens...) 
} -func (t *TestLeaseListener) onStartedLeading(_ context.Context) { +func (t *TestLeaseListener) onStartedLeading(_ *armadacontext.Context) { t.handleNewToken() } diff --git a/internal/scheduler/metrics.go b/internal/scheduler/metrics.go index 15da0d6c478..a7fb2f08c78 100644 --- a/internal/scheduler/metrics.go +++ b/internal/scheduler/metrics.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "strings" "sync/atomic" "time" @@ -11,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" commonmetrics "github.com/armadaproject/armada/internal/common/metrics" "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/scheduler/database" @@ -76,7 +76,7 @@ func NewMetricsCollector( } // Run enters s a loop which updates the metrics every refreshPeriod until the supplied context is cancelled -func (c *MetricsCollector) Run(ctx context.Context) error { +func (c *MetricsCollector) Run(ctx *armadacontext.Context) error { ticker := c.clock.NewTicker(c.refreshPeriod) log.Infof("Will update metrics every %s", c.refreshPeriod) for { @@ -108,7 +108,7 @@ func (c *MetricsCollector) Collect(metrics chan<- prometheus.Metric) { } } -func (c *MetricsCollector) refresh(ctx context.Context) error { +func (c *MetricsCollector) refresh(ctx *armadacontext.Context) error { log.Debugf("Refreshing prometheus metrics") start := time.Now() queueMetrics, err := c.updateQueueMetrics(ctx) @@ -125,7 +125,7 @@ func (c *MetricsCollector) refresh(ctx context.Context) error { return nil } -func (c *MetricsCollector) updateQueueMetrics(ctx context.Context) ([]prometheus.Metric, error) { +func (c *MetricsCollector) updateQueueMetrics(ctx *armadacontext.Context) ([]prometheus.Metric, error) { queues, err := c.queueRepository.GetAllQueues() if err != nil { return nil, err @@ -212,7 +212,7 @@ type clusterMetricKey struct { nodeType string } -func (c *MetricsCollector) updateClusterMetrics(ctx context.Context) ([]prometheus.Metric, error) { +func (c *MetricsCollector) updateClusterMetrics(ctx *armadacontext.Context) ([]prometheus.Metric, error) { executors, err := c.executorRepository.GetExecutors(ctx) if err != nil { return nil, err diff --git a/internal/scheduler/metrics_test.go b/internal/scheduler/metrics_test.go index 52c89eb6641..0bbcd9090c7 100644 --- a/internal/scheduler/metrics_test.go +++ b/internal/scheduler/metrics_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "testing" "time" @@ -12,6 +11,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" commonmetrics "github.com/armadaproject/armada/internal/common/metrics" "github.com/armadaproject/armada/internal/scheduler/database" "github.com/armadaproject/armada/internal/scheduler/jobdb" @@ -86,7 +86,7 @@ func TestMetricsCollector_TestCollect_QueueMetrics(t *testing.T) { t.Run(name, func(t *testing.T) { ctrl := gomock.NewController(t) testClock := clock.NewFakeClock(testfixtures.BaseTime) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() // set up job db with initial jobs @@ -236,7 +236,7 @@ func TestMetricsCollector_TestCollect_ClusterMetrics(t *testing.T) { t.Run(name, func(t *testing.T) { ctrl := gomock.NewController(t) testClock := clock.NewFakeClock(testfixtures.BaseTime) - ctx, cancel 
:= context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() // set up job db with initial jobs @@ -303,7 +303,7 @@ type MockPoolAssigner struct { poolsById map[string]string } -func (m MockPoolAssigner) Refresh(_ context.Context) error { +func (m MockPoolAssigner) Refresh(_ *armadacontext.Context) error { return nil } diff --git a/internal/scheduler/mocks/mock_repositories.go b/internal/scheduler/mocks/mock_repositories.go index 9a8f6efee1a..c2924402b9b 100644 --- a/internal/scheduler/mocks/mock_repositories.go +++ b/internal/scheduler/mocks/mock_repositories.go @@ -5,10 +5,10 @@ package schedulermocks import ( - context "context" reflect "reflect" time "time" + armadacontext "github.com/armadaproject/armada/internal/common/armadacontext" database "github.com/armadaproject/armada/internal/scheduler/database" schedulerobjects "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" armadaevents "github.com/armadaproject/armada/pkg/armadaevents" @@ -40,7 +40,7 @@ func (m *MockExecutorRepository) EXPECT() *MockExecutorRepositoryMockRecorder { } // GetExecutors mocks base method. -func (m *MockExecutorRepository) GetExecutors(arg0 context.Context) ([]*schedulerobjects.Executor, error) { +func (m *MockExecutorRepository) GetExecutors(arg0 *armadacontext.Context) ([]*schedulerobjects.Executor, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetExecutors", arg0) ret0, _ := ret[0].([]*schedulerobjects.Executor) @@ -55,7 +55,7 @@ func (mr *MockExecutorRepositoryMockRecorder) GetExecutors(arg0 interface{}) *go } // GetLastUpdateTimes mocks base method. -func (m *MockExecutorRepository) GetLastUpdateTimes(arg0 context.Context) (map[string]time.Time, error) { +func (m *MockExecutorRepository) GetLastUpdateTimes(arg0 *armadacontext.Context) (map[string]time.Time, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLastUpdateTimes", arg0) ret0, _ := ret[0].(map[string]time.Time) @@ -70,7 +70,7 @@ func (mr *MockExecutorRepositoryMockRecorder) GetLastUpdateTimes(arg0 interface{ } // StoreExecutor mocks base method. -func (m *MockExecutorRepository) StoreExecutor(arg0 context.Context, arg1 *schedulerobjects.Executor) error { +func (m *MockExecutorRepository) StoreExecutor(arg0 *armadacontext.Context, arg1 *schedulerobjects.Executor) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StoreExecutor", arg0, arg1) ret0, _ := ret[0].(error) @@ -145,7 +145,7 @@ func (m *MockJobRepository) EXPECT() *MockJobRepositoryMockRecorder { } // CountReceivedPartitions mocks base method. -func (m *MockJobRepository) CountReceivedPartitions(arg0 context.Context, arg1 uuid.UUID) (uint32, error) { +func (m *MockJobRepository) CountReceivedPartitions(arg0 *armadacontext.Context, arg1 uuid.UUID) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CountReceivedPartitions", arg0, arg1) ret0, _ := ret[0].(uint32) @@ -160,7 +160,7 @@ func (mr *MockJobRepositoryMockRecorder) CountReceivedPartitions(arg0, arg1 inte } // FetchJobRunErrors mocks base method. 
-func (m *MockJobRepository) FetchJobRunErrors(arg0 context.Context, arg1 []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { +func (m *MockJobRepository) FetchJobRunErrors(arg0 *armadacontext.Context, arg1 []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobRunErrors", arg0, arg1) ret0, _ := ret[0].(map[uuid.UUID]*armadaevents.Error) @@ -175,7 +175,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobRunErrors(arg0, arg1 interface{ } // FetchJobRunLeases mocks base method. -func (m *MockJobRepository) FetchJobRunLeases(arg0 context.Context, arg1 string, arg2 uint, arg3 []uuid.UUID) ([]*database.JobRunLease, error) { +func (m *MockJobRepository) FetchJobRunLeases(arg0 *armadacontext.Context, arg1 string, arg2 uint, arg3 []uuid.UUID) ([]*database.JobRunLease, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobRunLeases", arg0, arg1, arg2, arg3) ret0, _ := ret[0].([]*database.JobRunLease) @@ -190,7 +190,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobRunLeases(arg0, arg1, arg2, arg } // FetchJobUpdates mocks base method. -func (m *MockJobRepository) FetchJobUpdates(arg0 context.Context, arg1, arg2 int64) ([]database.Job, []database.Run, error) { +func (m *MockJobRepository) FetchJobUpdates(arg0 *armadacontext.Context, arg1, arg2 int64) ([]database.Job, []database.Run, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobUpdates", arg0, arg1, arg2) ret0, _ := ret[0].([]database.Job) @@ -206,7 +206,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobUpdates(arg0, arg1, arg2 interf } // FindInactiveRuns mocks base method. -func (m *MockJobRepository) FindInactiveRuns(arg0 context.Context, arg1 []uuid.UUID) ([]uuid.UUID, error) { +func (m *MockJobRepository) FindInactiveRuns(arg0 *armadacontext.Context, arg1 []uuid.UUID) ([]uuid.UUID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FindInactiveRuns", arg0, arg1) ret0, _ := ret[0].([]uuid.UUID) diff --git a/internal/scheduler/pool_assigner.go b/internal/scheduler/pool_assigner.go index 94aa07e4908..9ff1f9b140c 100644 --- a/internal/scheduler/pool_assigner.go +++ b/internal/scheduler/pool_assigner.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "time" "github.com/gogo/protobuf/proto" @@ -10,6 +9,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/types" "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -22,7 +22,7 @@ import ( // PoolAssigner allows jobs to be assigned to a pool // Note that this is intended only for use with metrics calculation type PoolAssigner interface { - Refresh(ctx context.Context) error + Refresh(ctx *armadacontext.Context) error AssignPool(j *jobdb.Job) (string, error) } @@ -71,7 +71,7 @@ func NewPoolAssigner(executorTimeout time.Duration, } // Refresh updates executor state -func (p *DefaultPoolAssigner) Refresh(ctx context.Context) error { +func (p *DefaultPoolAssigner) Refresh(ctx *armadacontext.Context) error { executors, err := p.executorRepository.GetExecutors(ctx) executorsByPool := map[string][]*executor{} poolByExecutorId := map[string]string{} diff --git a/internal/scheduler/pool_assigner_test.go b/internal/scheduler/pool_assigner_test.go index f2508295e65..7734b6195be 100644 --- a/internal/scheduler/pool_assigner_test.go +++ 
b/internal/scheduler/pool_assigner_test.go @@ -1,17 +1,16 @@ package scheduler import ( - "context" "testing" "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "k8s.io/apimachinery/pkg/util/clock" - "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/jobdb" schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks" @@ -48,7 +47,7 @@ func TestPoolAssigner_AssignPool(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) diff --git a/internal/scheduler/preempting_queue_scheduler.go b/internal/scheduler/preempting_queue_scheduler.go index ebf2c35b390..fd0c0d9e079 100644 --- a/internal/scheduler/preempting_queue_scheduler.go +++ b/internal/scheduler/preempting_queue_scheduler.go @@ -1,12 +1,10 @@ package scheduler import ( - "context" "fmt" "math/rand" "time" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/hashicorp/go-memdb" "github.com/pkg/errors" "golang.org/x/exp/maps" @@ -14,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/types" @@ -108,9 +107,7 @@ func (sch *PreemptingQueueScheduler) EnableNewPreemptionStrategy() { // Schedule // - preempts jobs belonging to queues with total allocation above their fair share and // - schedules new jobs belonging to queues with total allocation less than their fair share. -func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithField("service", "PreemptingQueueScheduler") +func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResult, error) { defer func() { sch.schedulingContext.Finished = time.Now() }() @@ -125,23 +122,18 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Evict preemptible jobs. 
totalCost := sch.schedulingContext.TotalCost() evictorResult, inMemoryJobRepo, err := sch.evict( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "evict for resource balancing"), - ), + armadacontext.WithLogField(ctx, "stage", "evict for resource balancing"), NewNodeEvictor( sch.jobRepo, sch.schedulingContext.PriorityClasses, sch.nodeEvictionProbability, - func(ctx context.Context, job interfaces.LegacySchedulerJob) bool { + func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { if job.GetAnnotations() == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("can't evict job %s: annotations not initialised", job.GetId()) + ctx.Log.Errorf("can't evict job %s: annotations not initialised", job.GetId()) return false } if job.GetNodeSelector() == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) + ctx.Log.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) return false } if qctx, ok := sch.schedulingContext.QueueSchedulingContexts[job.GetQueue()]; ok { @@ -168,10 +160,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Re-schedule evicted jobs/schedule new jobs. schedulerResult, err := sch.schedule( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "re-schedule after balancing eviction"), - ), + armadacontext.WithLogField(ctx, "stage", "re-schedule after balancing eviction"), inMemoryJobRepo, sch.jobRepo, ) @@ -189,10 +178,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Evict jobs on oversubscribed nodes. evictorResult, inMemoryJobRepo, err = sch.evict( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "evict oversubscribed"), - ), + armadacontext.WithLogField(ctx, "stage", "evict oversubscribed"), NewOversubscribedEvictor( sch.jobRepo, sch.schedulingContext.PriorityClasses, @@ -226,10 +212,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe // Since no new jobs are considered in this round, the scheduling key check brings no benefit. sch.SkipUnsuccessfulSchedulingKeyCheck() schedulerResult, err = sch.schedule( - ctxlogrus.ToContext( - ctx, - log.WithField("stage", "schedule after oversubscribed eviction"), - ), + armadacontext.WithLogField(ctx, "stage", "schedule after oversubscribed eviction"), inMemoryJobRepo, // Only evicted jobs should be scheduled in this round, // so we provide an empty repo for queued jobs. 
@@ -258,10 +241,10 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe return nil, err } if s := JobsSummary(preemptedJobs); s != "" { - log.Infof("preempting running jobs; %s", s) + ctx.Log.Infof("preempting running jobs; %s", s) } if s := JobsSummary(scheduledJobs); s != "" { - log.Infof("scheduling new jobs; %s", s) + ctx.Log.Infof("scheduling new jobs; %s", s) } if sch.enableAssertions { err := sch.assertions( @@ -282,7 +265,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx context.Context) (*SchedulerRe }, nil } -func (sch *PreemptingQueueScheduler) evict(ctx context.Context, evictor *Evictor) (*EvictorResult, *InMemoryJobRepository, error) { +func (sch *PreemptingQueueScheduler) evict(ctx *armadacontext.Context, evictor *Evictor) (*EvictorResult, *InMemoryJobRepository, error) { if evictor == nil { return &EvictorResult{}, NewInMemoryJobRepository(sch.schedulingContext.PriorityClasses), nil } @@ -348,7 +331,7 @@ func (sch *PreemptingQueueScheduler) evict(ctx context.Context, evictor *Evictor // When evicting jobs, gangs may have been partially evicted. // Here, we evict all jobs in any gang for which at least one job was already evicted. -func (sch *PreemptingQueueScheduler) evictGangs(ctx context.Context, txn *memdb.Txn, previousEvictorResult *EvictorResult) (*EvictorResult, error) { +func (sch *PreemptingQueueScheduler) evictGangs(ctx *armadacontext.Context, txn *memdb.Txn, previousEvictorResult *EvictorResult) (*EvictorResult, error) { gangJobIds, gangNodeIds, err := sch.collectIdsForGangEviction(previousEvictorResult.EvictedJobsById) if err != nil { return nil, err @@ -512,7 +495,7 @@ func (q MinimalQueue) GetWeight() float64 { // addEvictedJobsToNodeDb adds evicted jobs to the NodeDb. // Needed to enable the nodeDb accounting for these when preempting. 
-func addEvictedJobsToNodeDb(ctx context.Context, sctx *schedulercontext.SchedulingContext, nodeDb *nodedb.NodeDb, inMemoryJobRepo *InMemoryJobRepository) error { +func addEvictedJobsToNodeDb(ctx *armadacontext.Context, sctx *schedulercontext.SchedulingContext, nodeDb *nodedb.NodeDb, inMemoryJobRepo *InMemoryJobRepository) error { gangItByQueue := make(map[string]*QueuedGangIterator) for _, qctx := range sctx.QueueSchedulingContexts { jobIt, err := inMemoryJobRepo.GetJobIterator(ctx, qctx.Queue) @@ -552,7 +535,7 @@ func addEvictedJobsToNodeDb(ctx context.Context, sctx *schedulercontext.Scheduli return nil } -func (sch *PreemptingQueueScheduler) schedule(ctx context.Context, inMemoryJobRepo *InMemoryJobRepository, jobRepo JobRepository) (*SchedulerResult, error) { +func (sch *PreemptingQueueScheduler) schedule(ctx *armadacontext.Context, inMemoryJobRepo *InMemoryJobRepository, jobRepo JobRepository) (*SchedulerResult, error) { jobIteratorByQueue := make(map[string]JobIterator) for _, qctx := range sch.schedulingContext.QueueSchedulingContexts { evictedIt, err := inMemoryJobRepo.GetJobIterator(ctx, qctx.Queue) @@ -717,9 +700,9 @@ func (sch *PreemptingQueueScheduler) assertions( type Evictor struct { jobRepo JobRepository priorityClasses map[string]types.PriorityClass - nodeFilter func(context.Context, *nodedb.Node) bool - jobFilter func(context.Context, interfaces.LegacySchedulerJob) bool - postEvictFunc func(context.Context, interfaces.LegacySchedulerJob, *nodedb.Node) + nodeFilter func(*armadacontext.Context, *nodedb.Node) bool + jobFilter func(*armadacontext.Context, interfaces.LegacySchedulerJob) bool + postEvictFunc func(*armadacontext.Context, interfaces.LegacySchedulerJob, *nodedb.Node) } type EvictorResult struct { @@ -735,7 +718,7 @@ func NewNodeEvictor( jobRepo JobRepository, priorityClasses map[string]types.PriorityClass, perNodeEvictionProbability float64, - jobFilter func(context.Context, interfaces.LegacySchedulerJob) bool, + jobFilter func(*armadacontext.Context, interfaces.LegacySchedulerJob) bool, random *rand.Rand, ) *Evictor { if perNodeEvictionProbability <= 0 { @@ -747,7 +730,7 @@ func NewNodeEvictor( return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: func(_ context.Context, node *nodedb.Node) bool { + nodeFilter: func(_ *armadacontext.Context, node *nodedb.Node) bool { return len(node.AllocatedByJobId) > 0 && random.Float64() < perNodeEvictionProbability }, jobFilter: jobFilter, @@ -769,11 +752,11 @@ func NewFilteredEvictor( return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: func(_ context.Context, node *nodedb.Node) bool { + nodeFilter: func(_ *armadacontext.Context, node *nodedb.Node) bool { shouldEvict := nodeIdsToEvict[node.Id] return shouldEvict }, - jobFilter: func(_ context.Context, job interfaces.LegacySchedulerJob) bool { + jobFilter: func(_ *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { shouldEvict := jobIdsToEvict[job.GetId()] return shouldEvict }, @@ -804,7 +787,7 @@ func NewOversubscribedEvictor( return &Evictor{ jobRepo: jobRepo, priorityClasses: priorityClasses, - nodeFilter: func(_ context.Context, node *nodedb.Node) bool { + nodeFilter: func(_ *armadacontext.Context, node *nodedb.Node) bool { overSubscribedPriorities = make(map[int32]bool) for p, rl := range node.AllocatableByPriority { if p < 0 { @@ -820,10 +803,9 @@ func NewOversubscribedEvictor( } return len(overSubscribedPriorities) > 0 && random.Float64() < perNodeEvictionProbability }, - jobFilter: func(ctx 
context.Context, job interfaces.LegacySchedulerJob) bool { + jobFilter: func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { if job.GetAnnotations() == nil { - log := ctxlogrus.Extract(ctx) - log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) + ctx.Log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) return false } priorityClassName := job.GetPriorityClassName() @@ -844,7 +826,7 @@ func NewOversubscribedEvictor( // Any node for which nodeFilter returns false is skipped. // Any job for which jobFilter returns true is evicted (if the node was not skipped). // If a job was evicted from a node, postEvictFunc is called with the corresponding job and node. -func (evi *Evictor) Evict(ctx context.Context, it nodedb.NodeIterator) (*EvictorResult, error) { +func (evi *Evictor) Evict(ctx *armadacontext.Context, it nodedb.NodeIterator) (*EvictorResult, error) { var jobFilter func(job interfaces.LegacySchedulerJob) bool if evi.jobFilter != nil { jobFilter = func(job interfaces.LegacySchedulerJob) bool { return evi.jobFilter(ctx, job) } @@ -898,12 +880,11 @@ func (evi *Evictor) Evict(ctx context.Context, it nodedb.NodeIterator) (*Evictor // TODO: This is only necessary for jobs not scheduled in this cycle. // Since jobs scheduled in this cycle can be re-scheduled onto another node without triggering a preemption. -func defaultPostEvictFunc(ctx context.Context, job interfaces.LegacySchedulerJob, node *nodedb.Node) { +func defaultPostEvictFunc(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob, node *nodedb.Node) { // Add annotation indicating to the scheduler this this job was evicted. annotations := job.GetAnnotations() if annotations == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("error evicting job %s: annotations not initialised", job.GetId()) + ctx.Log.Errorf("error evicting job %s: annotations not initialised", job.GetId()) } else { annotations[schedulerconfig.IsEvictedAnnotation] = "true" } @@ -911,8 +892,7 @@ func defaultPostEvictFunc(ctx context.Context, job interfaces.LegacySchedulerJob // Add node selector ensuring this job is only re-scheduled onto the node it was evicted from. 
nodeSelector := job.GetNodeSelector() if nodeSelector == nil { - log := ctxlogrus.Extract(ctx) - log.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) + ctx.Log.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) } else { nodeSelector[schedulerconfig.NodeIdLabel] = node.Id } diff --git a/internal/scheduler/preempting_queue_scheduler_test.go b/internal/scheduler/preempting_queue_scheduler_test.go index dc84cd225ae..84538cdccc2 100644 --- a/internal/scheduler/preempting_queue_scheduler_test.go +++ b/internal/scheduler/preempting_queue_scheduler_test.go @@ -1,13 +1,11 @@ package scheduler import ( - "context" "fmt" "math/rand" "testing" "time" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,6 +15,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaslices "github.com/armadaproject/armada/internal/common/slices" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" @@ -55,7 +54,7 @@ func TestEvictOversubscribed(t *testing.T) { nil, ) it := NewInMemoryNodeIterator([]*nodedb.Node{entry}) - result, err := evictor.Evict(context.Background(), it) + result, err := evictor.Evict(armadacontext.Background(), it) require.NoError(t, err) prioritiesByName := configuration.PriorityByPriorityClassName(testfixtures.TestPriorityClasses) @@ -1459,7 +1458,7 @@ func TestPreemptingQueueScheduler(t *testing.T) { if tc.SchedulingConfig.EnableNewPreemptionStrategy { sch.EnableNewPreemptionStrategy() } - result, err := sch.Schedule(ctxlogrus.ToContext(context.Background(), log)) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(t, err) jobIdsByGangId = sch.jobIdsByGangId gangIdByJobId = sch.gangIdByJobId @@ -1734,7 +1733,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nil, nil, ) - result, err := sch.Schedule(context.Background()) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(b, err) require.Equal(b, 0, len(result.PreemptedJobs)) @@ -1790,7 +1789,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { nil, nil, ) - result, err := sch.Schedule(context.Background()) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(b, err) // We expect the system to be in steady-state, i.e., no preempted/scheduled jobs. 
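The hunks above replace ctxlogrus-based logging and plain context.Context with the armadacontext package: call sites now use armadacontext.Background(), armadacontext.WithTimeout(...), armadacontext.WithLogField(ctx, "stage", ...), armadacontext.ErrGroup(ctx), and log through ctx.Log. The minimal sketch below shows roughly what a wrapper with that surface looks like; it is inferred from these call sites only, not the actual implementation in internal/common/armadacontext, so the package, field, and parameter names are assumptions.

```go
package armadacontextsketch // hypothetical package; the real one is internal/common/armadacontext

import (
	"context"
	"time"

	"github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"
)

// Context is a context.Context with a logrus logger attached, so call sites can
// write ctx.Log.Infof(...) instead of extracting a logger via ctxlogrus.
// Because context.Context is embedded, *Context still satisfies context.Context.
type Context struct {
	context.Context
	Log *logrus.Entry
}

// Background mirrors context.Background but attaches a default logger.
func Background() *Context {
	return &Context{Context: context.Background(), Log: logrus.NewEntry(logrus.StandardLogger())}
}

// WithTimeout mirrors context.WithTimeout while preserving the logger.
func WithTimeout(parent *Context, d time.Duration) (*Context, context.CancelFunc) {
	inner, cancel := context.WithTimeout(parent.Context, d)
	return &Context{Context: inner, Log: parent.Log}, cancel
}

// WithLogField returns a child context whose logger carries an extra field,
// as used for the "stage" annotations in the preempting scheduler above.
func WithLogField(parent *Context, key string, value interface{}) *Context {
	return &Context{Context: parent.Context, Log: parent.Log.WithField(key, value)}
}

// ErrGroup mirrors errgroup.WithContext but keeps the logger attached.
func ErrGroup(parent *Context) (*errgroup.Group, *Context) {
	g, inner := errgroup.WithContext(parent.Context)
	return g, &Context{Context: inner, Log: parent.Log}
}
```

TODO() and the other helpers seen elsewhere in the patch would follow the same pattern: wrap the corresponding context constructor and carry the logger through.
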
diff --git a/internal/scheduler/proxying_reports_server_test.go b/internal/scheduler/proxying_reports_server_test.go index 0dc81b54bf9..98f7c11fa97 100644 --- a/internal/scheduler/proxying_reports_server_test.go +++ b/internal/scheduler/proxying_reports_server_test.go @@ -1,13 +1,13 @@ package scheduler import ( - "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -24,7 +24,7 @@ func TestProxyingSchedulingReportsServer_GetJobReports(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, jobReportsClient := setupProxyingSchedulerReportsServerTest(t) @@ -62,7 +62,7 @@ func TestProxyingSchedulingReportsServer_GetSchedulingReport(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, jobReportsClient := setupProxyingSchedulerReportsServerTest(t) @@ -100,7 +100,7 @@ func TestProxyingSchedulingReportsServer_GetQueueReport(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() sut, jobReportsClient := setupProxyingSchedulerReportsServerTest(t) diff --git a/internal/scheduler/publisher.go b/internal/scheduler/publisher.go index 0ae0595303b..0b308141961 100644 --- a/internal/scheduler/publisher.go +++ b/internal/scheduler/publisher.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "strconv" "sync" @@ -13,6 +12,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" @@ -28,12 +28,12 @@ const ( type Publisher interface { // PublishMessages will publish the supplied messages. A LeaderToken is provided and the // implementor may decide whether to publish based on the status of this token - PublishMessages(ctx context.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error + PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error // PublishMarkers publishes a single marker message for each Pulsar partition. Each marker // massage contains the supplied group id, which allows all marker messages for a given call // to be identified. The uint32 returned is the number of messages published - PublishMarkers(ctx context.Context, groupId uuid.UUID) (uint32, error) + PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) } // PulsarPublisher is the default implementation of Publisher @@ -77,7 +77,7 @@ func NewPulsarPublisher( // PublishMessages publishes all event sequences to pulsar. Event sequences for a given jobset will be combined into // single event sequences up to maxMessageBatchSize. 
-func (p *PulsarPublisher) PublishMessages(ctx context.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error { +func (p *PulsarPublisher) PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error { sequences := eventutil.CompactEventSequences(events) sequences, err := eventutil.LimitSequencesByteSize(sequences, p.maxMessageBatchSize, true) if err != nil { @@ -104,7 +104,7 @@ func (p *PulsarPublisher) PublishMessages(ctx context.Context, events []*armadae // Send messages if shouldPublish() { log.Debugf("Am leader so will publish") - sendCtx, cancel := context.WithTimeout(ctx, p.pulsarSendTimeout) + sendCtx, cancel := armadacontext.WithTimeout(ctx, p.pulsarSendTimeout) errored := false for _, msg := range msgs { p.producer.SendAsync(sendCtx, msg, func(_ pulsar.MessageID, _ *pulsar.ProducerMessage, err error) { @@ -128,7 +128,7 @@ func (p *PulsarPublisher) PublishMessages(ctx context.Context, events []*armadae // PublishMarkers sends one pulsar message (containing an armadaevents.PartitionMarker) to each partition // of the producer's Pulsar topic. -func (p *PulsarPublisher) PublishMarkers(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (p *PulsarPublisher) PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { for i := 0; i < p.numPartitions; i++ { pm := &armadaevents.PartitionMarker{ GroupId: armadaevents.ProtoUuidFromUuid(groupId), diff --git a/internal/scheduler/publisher_test.go b/internal/scheduler/publisher_test.go index a524f9e26b9..6ecb200d416 100644 --- a/internal/scheduler/publisher_test.go +++ b/internal/scheduler/publisher_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "math" "testing" @@ -15,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/mocks" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/pkg/armadaevents" @@ -89,7 +89,7 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) mockPulsarClient := mocks.NewMockClient(ctrl) @@ -106,7 +106,7 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { mockPulsarProducer. EXPECT(). SendAsync(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, msg *pulsar.ProducerMessage, callback func(pulsar.MessageID, *pulsar.ProducerMessage, error)) { + DoAndReturn(func(_ *armadacontext.Context, msg *pulsar.ProducerMessage, callback func(pulsar.MessageID, *pulsar.ProducerMessage, error)) { es := &armadaevents.EventSequence{} err := proto.Unmarshal(msg.Payload, es) require.NoError(t, err) @@ -177,7 +177,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { mockPulsarProducer. EXPECT(). Send(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, msg *pulsar.ProducerMessage) (pulsar.MessageID, error) { + DoAndReturn(func(_ *armadacontext.Context, msg *pulsar.ProducerMessage) (pulsar.MessageID, error) { numPublished++ key, ok := msg.Properties[explicitPartitionKey] if ok { @@ -190,7 +190,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { }).AnyTimes() options := pulsar.ProducerOptions{Topic: topic} - ctx := context.TODO() + ctx := armadacontext.TODO() publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second) require.NoError(t, err) diff --git a/internal/scheduler/queue_scheduler.go b/internal/scheduler/queue_scheduler.go index 825c9f26bfb..cf03c7af3fc 100644 --- a/internal/scheduler/queue_scheduler.go +++ b/internal/scheduler/queue_scheduler.go @@ -2,13 +2,13 @@ package scheduler import ( "container/heap" - "context" "reflect" "time" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/logging" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -60,7 +60,7 @@ func (sch *QueueScheduler) SkipUnsuccessfulSchedulingKeyCheck() { sch.gangScheduler.SkipUnsuccessfulSchedulingKeyCheck() } -func (sch *QueueScheduler) Schedule(ctx context.Context) (*SchedulerResult, error) { +func (sch *QueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResult, error) { nodeIdByJobId := make(map[string]string) scheduledJobs := make([]interfaces.LegacySchedulerJob, 0) for { diff --git a/internal/scheduler/queue_scheduler_test.go b/internal/scheduler/queue_scheduler_test.go index cbbc537e495..3832db7ceba 100644 --- a/internal/scheduler/queue_scheduler_test.go +++ b/internal/scheduler/queue_scheduler_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "testing" @@ -13,6 +12,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/util" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" @@ -512,14 +512,14 @@ func TestQueueScheduler(t *testing.T) { ) jobIteratorByQueue := make(map[string]JobIterator) for queue := range tc.PriorityFactorByQueue { - it, err := jobRepo.GetJobIterator(context.Background(), queue) + it, err := jobRepo.GetJobIterator(armadacontext.Background(), queue) require.NoError(t, err) jobIteratorByQueue[queue] = it } sch, err := NewQueueScheduler(sctx, constraints, nodeDb, jobIteratorByQueue) require.NoError(t, err) - result, err := sch.Schedule(context.Background()) + result, err := sch.Schedule(armadacontext.Background()) require.NoError(t, err) // Check that the right jobs got scheduled. 
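The test-file changes in this commit all follow the same mechanical pattern, so a single condensed example captures it. The test below is illustrative, not part of the patch; it relies only on the fact that armadacontext.Context embeds a standard context.Context, so deadlines behave exactly as before.

```go
package scheduler

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

// TestArmadaContextDeadline shows the migration applied throughout these files:
// armadacontext.WithTimeout replaces context.WithTimeout and the returned context
// still satisfies context.Context.
func TestArmadaContextDeadline(t *testing.T) {
	ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
	defer cancel()

	deadline, ok := ctx.Deadline()
	require.True(t, ok)
	require.True(t, deadline.After(time.Now()))
}
```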
diff --git a/internal/scheduler/reports_test.go b/internal/scheduler/reports_test.go index b3c8f568d38..fcc0837188a 100644 --- a/internal/scheduler/reports_test.go +++ b/internal/scheduler/reports_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/api/resource" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" @@ -159,7 +159,7 @@ func TestAddGetSchedulingContext(t *testing.T) { func TestTestAddGetSchedulingContextConcurrency(t *testing.T) { repo, err := NewSchedulingContextRepository(10) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), time.Second) defer cancel() for _, executorId := range []string{"foo", "bar"} { go func(executorId string) { @@ -202,7 +202,7 @@ func TestReportDoesNotExist(t *testing.T) { require.NoError(t, err) err = repo.AddSchedulingContext(testSchedulingContext("executor-01")) require.NoError(t, err) - ctx := context.Background() + ctx := armadacontext.Background() queue := "queue-does-not-exist" jobId := util.NewULID() diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 88c43b14a0f..ccc4d998ff5 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -1,19 +1,18 @@ package scheduler import ( - "context" "fmt" "time" "github.com/gogo/protobuf/proto" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/stringinterner" "github.com/armadaproject/armada/internal/scheduler/database" @@ -116,40 +115,37 @@ func NewScheduler( } // Run enters the scheduling loop, which will continue until ctx is cancelled. -func (s *Scheduler) Run(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("service", "scheduler") - ctx = ctxlogrus.ToContext(ctx, log) - log.Infof("starting scheduler with cycle time %s", s.cyclePeriod) - defer log.Info("scheduler stopped") +func (s *Scheduler) Run(ctx *armadacontext.Context) error { + ctx.Log.Infof("starting scheduler with cycle time %s", s.cyclePeriod) + defer ctx.Log.Info("scheduler stopped") // JobDb initialisation. 
start := s.clock.Now() if err := s.initialise(ctx); err != nil { return err } - log.Infof("JobDb initialised in %s", s.clock.Since(start)) + ctx.Log.Infof("JobDb initialised in %s", s.clock.Since(start)) ticker := s.clock.NewTicker(s.cyclePeriod) prevLeaderToken := InvalidLeaderToken() for { select { case <-ctx.Done(): - log.Infof("context cancelled; returning.") + ctx.Log.Infof("context cancelled; returning.") return ctx.Err() case <-ticker.C(): start := s.clock.Now() leaderToken := s.leaderController.GetToken() fullUpdate := false - log.Infof("received leaderToken; leader status is %t", leaderToken.leader) + ctx.Log.Infof("received leaderToken; leader status is %t", leaderToken.leader) // If we are becoming leader then we must ensure we have caught up to all Pulsar messages if leaderToken.leader && leaderToken != prevLeaderToken { - log.Infof("becoming leader") - syncContext, cancel := context.WithTimeout(ctx, 5*time.Minute) + ctx.Log.Infof("becoming leader") + syncContext, cancel := armadacontext.WithTimeout(ctx, 5*time.Minute) err := s.ensureDbUpToDate(syncContext, 1*time.Second) if err != nil { - log.WithError(err).Error("could not become leader") + logging.WithStacktrace(ctx.Log, err).Error("could not become leader") leaderToken = InvalidLeaderToken() } else { fullUpdate = true @@ -169,7 +165,7 @@ func (s *Scheduler) Run(ctx context.Context) error { result, err := s.cycle(ctx, fullUpdate, leaderToken, shouldSchedule) if err != nil { - logging.WithStacktrace(log, err).Error("scheduling cycle failure") + logging.WithStacktrace(ctx.Log, err).Error("scheduling cycle failure") leaderToken = InvalidLeaderToken() } @@ -181,10 +177,10 @@ func (s *Scheduler) Run(ctx context.Context) error { // Only the leader token does real scheduling rounds. s.metrics.ReportScheduleCycleTime(cycleTime) s.metrics.ReportSchedulerResult(result) - log.Infof("scheduling cycle completed in %s", cycleTime) + ctx.Log.Infof("scheduling cycle completed in %s", cycleTime) } else { s.metrics.ReportReconcileCycleTime(cycleTime) - log.Infof("reconciliation cycle completed in %s", cycleTime) + ctx.Log.Infof("reconciliation cycle completed in %s", cycleTime) } prevLeaderToken = leaderToken @@ -198,11 +194,9 @@ func (s *Scheduler) Run(ctx context.Context) error { // cycle is a single iteration of the main scheduling loop. // If updateAll is true, we generate events from all jobs in the jobDb. // Otherwise, we only generate events from jobs updated since the last cycle. -func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken LeaderToken, shouldSchedule bool) (overallSchedulerResult SchedulerResult, err error) { +func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToken LeaderToken, shouldSchedule bool) (overallSchedulerResult SchedulerResult, err error) { overallSchedulerResult = SchedulerResult{EmptyResult: true} - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "cycle") // Update job state. 
updatedJobs, err := s.syncState(ctx) if err != nil { @@ -244,7 +238,7 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade } var resultEvents []*armadaevents.EventSequence - resultEvents, err = s.eventsFromSchedulerResult(txn, result) + resultEvents, err = s.eventsFromSchedulerResult(result) if err != nil { return } @@ -262,22 +256,19 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade if err = s.publisher.PublishMessages(ctx, events, isLeader); err != nil { return } - log.Infof("published %d events to pulsar in %s", len(events), s.clock.Since(start)) + ctx.Log.Infof("published %d events to pulsar in %s", len(events), s.clock.Since(start)) txn.Commit() return } // syncState updates jobs in jobDb to match state in postgres and returns all updated jobs. -func (s *Scheduler) syncState(ctx context.Context) ([]*jobdb.Job, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "syncState") - +func (s *Scheduler) syncState(ctx *armadacontext.Context) ([]*jobdb.Job, error) { start := s.clock.Now() updatedJobs, updatedRuns, err := s.jobRepository.FetchJobUpdates(ctx, s.jobsSerial, s.runsSerial) if err != nil { return nil, err } - log.Infof("received %d updated jobs and %d updated job runs in %s", len(updatedJobs), len(updatedRuns), s.clock.Since(start)) + ctx.Log.Infof("received %d updated jobs and %d updated job runs in %s", len(updatedJobs), len(updatedRuns), s.clock.Since(start)) txn := s.jobDb.WriteTxn() defer txn.Abort() @@ -321,7 +312,7 @@ func (s *Scheduler) syncState(ctx context.Context) ([]*jobdb.Job, error) { // If the job is nil or terminal at this point then it cannot be active. // In this case we can ignore the run. if job == nil || job.InTerminalState() { - log.Debugf("job %s is not active; ignoring update for run %s", jobId, dbRun.RunID) + ctx.Log.Debugf("job %s is not active; ignoring update for run %s", jobId, dbRun.RunID) continue } } @@ -391,7 +382,7 @@ func (s *Scheduler) addNodeAntiAffinitiesForAttemptedRunsIfSchedulable(job *jobd } // eventsFromSchedulerResult generates necessary EventSequences from the provided SchedulerResult. -func (s *Scheduler) eventsFromSchedulerResult(txn *jobdb.Txn, result *SchedulerResult) ([]*armadaevents.EventSequence, error) { +func (s *Scheduler) eventsFromSchedulerResult(result *SchedulerResult) ([]*armadaevents.EventSequence, error) { return EventsFromSchedulerResult(result, s.clock.Now()) } @@ -507,7 +498,7 @@ func AppendEventSequencesFromScheduledJobs(eventSequences []*armadaevents.EventS // generateUpdateMessages generates EventSequences representing the state changes on updated jobs // If there are no state changes then an empty slice will be returned -func (s *Scheduler) generateUpdateMessages(ctx context.Context, updatedJobs []*jobdb.Job, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { +func (s *Scheduler) generateUpdateMessages(ctx *armadacontext.Context, updatedJobs []*jobdb.Job, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { failedRunIds := make([]uuid.UUID, 0, len(updatedJobs)) for _, job := range updatedJobs { run := job.LatestRun() @@ -708,10 +699,7 @@ func (s *Scheduler) generateUpdateMessagesFromJob(job *jobdb.Job, jobRunErrors m // expireJobsIfNecessary removes any jobs from the JobDb which are running on stale executors. 
// It also generates an EventSequence for each job, indicating that both the run and the job has failed // Note that this is different behaviour from the old scheduler which would allow expired jobs to be rerun -func (s *Scheduler) expireJobsIfNecessary(ctx context.Context, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "expireJobsIfNecessary") - +func (s *Scheduler) expireJobsIfNecessary(ctx *armadacontext.Context, txn *jobdb.Txn) ([]*armadaevents.EventSequence, error) { heartbeatTimes, err := s.executorRepository.GetLastUpdateTimes(ctx) if err != nil { return nil, err @@ -726,14 +714,14 @@ func (s *Scheduler) expireJobsIfNecessary(ctx context.Context, txn *jobdb.Txn) ( // has been completely removed for executor, heartbeat := range heartbeatTimes { if heartbeat.Before(cutOff) { - log.Warnf("Executor %s has not reported a hearbeart since %v. Will expire all jobs running on this executor", executor, heartbeat) + ctx.Log.Warnf("Executor %s has not reported a hearbeart since %v. Will expire all jobs running on this executor", executor, heartbeat) staleExecutors[executor] = true } } // All clusters have had a heartbeat recently. No need to expire any jobs if len(staleExecutors) == 0 { - log.Infof("No stale executors found. No jobs need to be expired") + ctx.Log.Infof("No stale executors found. No jobs need to be expired") return nil, nil } @@ -750,7 +738,7 @@ func (s *Scheduler) expireJobsIfNecessary(ctx context.Context, txn *jobdb.Txn) ( run := job.LatestRun() if run != nil && !job.Queued() && staleExecutors[run.Executor()] { - log.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) + ctx.Log.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) jobsToUpdate = append(jobsToUpdate, job.WithQueued(false).WithFailed(true).WithUpdatedRun(run.WithFailed(true))) jobId, err := armadaevents.ProtoUuidFromUlidString(job.Id()) @@ -808,16 +796,14 @@ func (s *Scheduler) now() *time.Time { // initialise builds the initial job db based on the current database state // right now this is quite dim and loads the entire database but in the future // we should be able to make it load active jobs/runs only -func (s *Scheduler) initialise(ctx context.Context) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "initialise") +func (s *Scheduler) initialise(ctx *armadacontext.Context) error { for { select { case <-ctx.Done(): return nil default: if _, err := s.syncState(ctx); err != nil { - log.WithError(err).Error("failed to initialise; trying again in 1 second") + ctx.Log.WithError(err).Error("failed to initialise; trying again in 1 second") time.Sleep(1 * time.Second) } else { // Initialisation succeeded. @@ -830,10 +816,7 @@ func (s *Scheduler) initialise(ctx context.Context) error { // ensureDbUpToDate blocks until that the database state contains all Pulsar messages sent *before* this // function was called. This is achieved firstly by publishing messages to Pulsar and then polling the // database until all messages have been written. 
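The Run and cycle changes above drop the ctxlogrus.Extract boilerplate because the logger now travels inside the context. Below is a minimal sketch of how a caller might attach a service-level field before starting the scheduler; the call site is hypothetical, but WithCancel and WithLogField are shown in the armadacontext diff later in this series.

```go
package main

import (
	"github.com/armadaproject/armada/internal/common/armadacontext"
)

func main() {
	// Attach the field once; every ctx.Log call made further down the stack inherits it,
	// replacing the old ctxlogrus.Extract + WithField("service", "scheduler") pattern.
	ctx, cancel := armadacontext.WithCancel(
		armadacontext.WithLogField(armadacontext.Background(), "service", "scheduler"),
	)
	defer cancel()

	ctx.Log.Info("starting scheduler")
	// scheduler.Run(ctx) would be invoked here in the real application.
}
```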
-func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Duration) error { - log := ctxlogrus.Extract(ctx) - log = log.WithField("function", "ensureDbUpToDate") - +func (s *Scheduler) ensureDbUpToDate(ctx *armadacontext.Context, pollInterval time.Duration) error { groupId := uuid.New() var numSent uint32 var err error @@ -847,7 +830,7 @@ func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Dura default: numSent, err = s.publisher.PublishMarkers(ctx, groupId) if err != nil { - log.WithError(err).Error("Error sending marker messages to pulsar") + ctx.Log.WithError(err).Error("Error sending marker messages to pulsar") s.clock.Sleep(pollInterval) } else { messagesSent = true @@ -863,13 +846,13 @@ func (s *Scheduler) ensureDbUpToDate(ctx context.Context, pollInterval time.Dura default: numReceived, err := s.jobRepository.CountReceivedPartitions(ctx, groupId) if err != nil { - log.WithError(err).Error("Error querying the database or marker messages") + ctx.Log.WithError(err).Error("Error querying the database or marker messages") } if numSent == numReceived { - log.Infof("Successfully ensured that database state is up to date") + ctx.Log.Infof("Successfully ensured that database state is up to date") return nil } - log.Infof("Recevied %d partitions, still waiting on %d", numReceived, numSent-numReceived) + ctx.Log.Infof("Recevied %d partitions, still waiting on %d", numReceived, numSent-numReceived) s.clock.Sleep(pollInterval) } } diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 1db7ad4ae8f..584f4552d42 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "sync" "testing" @@ -15,6 +14,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" protoutil "github.com/armadaproject/armada/internal/common/proto" "github.com/armadaproject/armada/internal/common/stringinterner" "github.com/armadaproject/armada/internal/common/util" @@ -527,7 +527,7 @@ func TestScheduler_TestCycle(t *testing.T) { txn.Commit() // run a scheduler cycle - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) _, err = sched.cycle(ctx, false, sched.leaderController.GetToken(), true) if tc.fetchError || tc.publishError || tc.scheduleError { assert.Error(t, err) @@ -684,7 +684,7 @@ func TestRun(t *testing.T) { sched.clock = testClock - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := armadacontext.WithCancel(armadacontext.Background()) //nolint:errcheck go sched.Run(ctx) @@ -861,7 +861,7 @@ func TestScheduler_TestSyncState(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() // Test objects @@ -943,31 +943,31 @@ type testJobRepository struct { numReceivedPartitions uint32 } -func (t *testJobRepository) FindInactiveRuns(ctx context.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { +func (t *testJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) { // TODO implement me panic("implement me") } -func (t *testJobRepository) FetchJobRunLeases(ctx 
context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*database.JobRunLease, error) { +func (t *testJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*database.JobRunLease, error) { // TODO implement me panic("implement me") } -func (t *testJobRepository) FetchJobUpdates(ctx context.Context, jobSerial int64, jobRunSerial int64) ([]database.Job, []database.Run, error) { +func (t *testJobRepository) FetchJobUpdates(ctx *armadacontext.Context, jobSerial int64, jobRunSerial int64) ([]database.Job, []database.Run, error) { if t.shouldError { return nil, nil, errors.New("error fetchiung job updates") } return t.updatedJobs, t.updatedRuns, nil } -func (t *testJobRepository) FetchJobRunErrors(ctx context.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { +func (t *testJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) { if t.shouldError { return nil, errors.New("error fetching job run errors") } return t.errors, nil } -func (t *testJobRepository) CountReceivedPartitions(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (t *testJobRepository) CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { if t.shouldError { return 0, errors.New("error counting received partitions") } @@ -979,18 +979,18 @@ type testExecutorRepository struct { shouldError bool } -func (t testExecutorRepository) GetExecutors(ctx context.Context) ([]*schedulerobjects.Executor, error) { +func (t testExecutorRepository) GetExecutors(ctx *armadacontext.Context) ([]*schedulerobjects.Executor, error) { panic("not implemented") } -func (t testExecutorRepository) GetLastUpdateTimes(ctx context.Context) (map[string]time.Time, error) { +func (t testExecutorRepository) GetLastUpdateTimes(ctx *armadacontext.Context) (map[string]time.Time, error) { if t.shouldError { return nil, errors.New("error getting last update time") } return t.updateTimes, nil } -func (t testExecutorRepository) StoreExecutor(ctx context.Context, executor *schedulerobjects.Executor) error { +func (t testExecutorRepository) StoreExecutor(ctx *armadacontext.Context, executor *schedulerobjects.Executor) error { panic("not implemented") } @@ -1001,7 +1001,7 @@ type testSchedulingAlgo struct { shouldError bool } -func (t *testSchedulingAlgo) Schedule(ctx context.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) { +func (t *testSchedulingAlgo) Schedule(ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) { t.numberOfScheduleCalls++ if t.shouldError { return nil, errors.New("error scheduling jobs") @@ -1049,7 +1049,7 @@ type testPublisher struct { shouldError bool } -func (t *testPublisher) PublishMessages(ctx context.Context, events []*armadaevents.EventSequence, _ func() bool) error { +func (t *testPublisher) PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, _ func() bool) error { t.events = events if t.shouldError { return errors.New("Error when publishing") @@ -1061,7 +1061,7 @@ func (t *testPublisher) Reset() { t.events = nil } -func (t *testPublisher) PublishMarkers(ctx context.Context, groupId uuid.UUID) (uint32, error) { +func (t *testPublisher) PublishMarkers(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error) { return 100, nil } diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 
ef742c3dc24..9ba1302c920 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -10,17 +10,16 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/go-redis/redis" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" "github.com/armadaproject/armada/internal/common" "github.com/armadaproject/armada/internal/common/app" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/auth" dbcommon "github.com/armadaproject/armada/internal/common/database" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" @@ -35,9 +34,7 @@ import ( // Run sets up a Scheduler application and runs it until a SIGTERM is received func Run(config schedulerconfig.Configuration) error { - g, ctx := errgroup.WithContext(app.CreateContextWithShutdown()) - logrusLogger := log.NewEntry(log.StandardLogger()) - ctx = ctxlogrus.ToContext(ctx, logrusLogger) + g, ctx := armadacontext.ErrGroup(app.CreateContextWithShutdown()) ////////////////////////////////////////////////////////////////////////// // Health Checks diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index 11f6b667f96..a1865d1601b 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -7,7 +7,6 @@ import ( "github.com/benbjohnson/immutable" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -16,6 +15,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/logging" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/util" @@ -34,7 +34,7 @@ import ( type SchedulingAlgo interface { // Schedule should assign jobs to nodes. // Any jobs that are scheduled should be marked as such in the JobDb using the transaction provided. - Schedule(ctx context.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) + Schedule(ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*SchedulerResult, error) } // FairSchedulingAlgo is a SchedulingAlgo based on PreemptingQueueScheduler. @@ -88,12 +88,10 @@ func NewFairSchedulingAlgo( // It maintains state of which executors it has considered already and may take multiple Schedule() calls to consider all executors if scheduling is slow. // Newly leased jobs are updated as such in the jobDb using the transaction provided and are also returned to the caller. func (l *FairSchedulingAlgo) Schedule( - ctx context.Context, + ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb, ) (*SchedulerResult, error) { - log := ctxlogrus.Extract(ctx) - overallSchedulerResult := &SchedulerResult{ NodeIdByJobId: make(map[string]string), SchedulingContexts: make([]*schedulercontext.SchedulingContext, 0, 0), @@ -101,11 +99,11 @@ func (l *FairSchedulingAlgo) Schedule( // Exit immediately if scheduling is disabled. 
if l.schedulingConfig.DisableScheduling { - log.Info("skipping scheduling - scheduling disabled") + ctx.Log.Info("skipping scheduling - scheduling disabled") return overallSchedulerResult, nil } - ctxWithTimeout, cancel := context.WithTimeout(ctx, l.maxSchedulingDuration) + ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, l.maxSchedulingDuration) defer cancel() fsctx, err := l.newFairSchedulingAlgoContext(ctx, txn, jobDb) @@ -123,7 +121,7 @@ func (l *FairSchedulingAlgo) Schedule( select { case <-ctxWithTimeout.Done(): // We've reached the scheduling time limit; exit gracefully. - log.Info("ending scheduling round early as we have hit the maximum scheduling duration") + ctx.Log.Info("ending scheduling round early as we have hit the maximum scheduling duration") return overallSchedulerResult, nil default: } @@ -142,7 +140,7 @@ func (l *FairSchedulingAlgo) Schedule( // Assume pool and minimumJobSize are consistent within the group. pool := executorGroup[0].Pool minimumJobSize := executorGroup[0].MinimumJobSize - log.Infof( + ctx.Log.Infof( "scheduling on executor group %s with capacity %s", executorGroupLabel, fsctx.totalCapacityByPool[pool].CompactString(), ) @@ -158,14 +156,14 @@ func (l *FairSchedulingAlgo) Schedule( // add the executorGroupLabel back to l.executorGroupsToSchedule such that we try it again next time, // and exit gracefully. l.executorGroupsToSchedule = append(l.executorGroupsToSchedule, executorGroupLabel) - log.Info("stopped scheduling early as we have hit the maximum scheduling duration") + ctx.Log.Info("stopped scheduling early as we have hit the maximum scheduling duration") break } else if err != nil { return nil, err } if l.schedulingContextRepository != nil { if err := l.schedulingContextRepository.AddSchedulingContext(sctx); err != nil { - logging.WithStacktrace(log, err).Error("failed to add scheduling context") + logging.WithStacktrace(ctx.Log, err).Error("failed to add scheduling context") } } @@ -239,7 +237,7 @@ type fairSchedulingAlgoContext struct { jobDb *jobdb.JobDb } -func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx context.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*fairSchedulingAlgoContext, error) { +func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Context, txn *jobdb.Txn, jobDb *jobdb.JobDb) (*fairSchedulingAlgoContext, error) { executors, err := l.executorRepository.GetExecutors(ctx) if err != nil { return nil, err @@ -330,7 +328,7 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx context.Context, t // scheduleOnExecutors schedules jobs on a specified set of executors. func (l *FairSchedulingAlgo) scheduleOnExecutors( - ctx context.Context, + ctx *armadacontext.Context, fsctx *fairSchedulingAlgoContext, pool string, minimumJobSize schedulerobjects.ResourceList, @@ -556,17 +554,16 @@ func (l *FairSchedulingAlgo) filterStaleExecutors(executors []*schedulerobjects. // // TODO: Let's also check that jobs are on the right nodes. 
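The Schedule loop above bounds its work with armadacontext.WithTimeout and checks Done() between executor groups so it can stop gracefully. The helper below is an illustrative reduction of that pattern and is not part of the patch.

```go
package scheduler

import (
	"time"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

// boundedRounds is illustrative: it performs up to `rounds` units of work but returns
// early, without error, once the scheduling deadline passes.
func boundedRounds(ctx *armadacontext.Context, maxDuration time.Duration, rounds int) int {
	ctxWithTimeout, cancel := armadacontext.WithTimeout(ctx, maxDuration)
	defer cancel()

	completed := 0
	for i := 0; i < rounds; i++ {
		select {
		case <-ctxWithTimeout.Done():
			ctx.Log.Info("ending early: maximum scheduling duration reached")
			return completed
		default:
		}
		// One executor group would be scheduled here.
		completed++
	}
	return completed
}
```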
func (l *FairSchedulingAlgo) filterLaggingExecutors( - ctx context.Context, + ctx *armadacontext.Context, executors []*schedulerobjects.Executor, leasedJobsByExecutor map[string][]*jobdb.Job, ) []*schedulerobjects.Executor { - log := ctxlogrus.Extract(ctx) activeExecutors := make([]*schedulerobjects.Executor, 0, len(executors)) for _, executor := range executors { leasedJobs := leasedJobsByExecutor[executor.Id] executorRuns, err := executor.AllRuns() if err != nil { - logging.WithStacktrace(log, err).Errorf("failed to retrieve runs for executor %s; will not be considered for scheduling", executor.Id) + logging.WithStacktrace(ctx.Log, err).Errorf("failed to retrieve runs for executor %s; will not be considered for scheduling", executor.Id) continue } executorRunIds := make(map[uuid.UUID]bool, len(executorRuns)) @@ -585,7 +582,7 @@ func (l *FairSchedulingAlgo) filterLaggingExecutors( if numUnacknowledgedJobs <= l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor { activeExecutors = append(activeExecutors, executor) } else { - log.Warnf( + ctx.Log.Warnf( "%d unacknowledged jobs on executor %s exceeds limit of %d; executor will not be considered for scheduling", numUnacknowledgedJobs, executor.Id, l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor, ) diff --git a/internal/scheduler/scheduling_algo_test.go b/internal/scheduler/scheduling_algo_test.go index 6cb6a276f6a..2bf766ecd40 100644 --- a/internal/scheduler/scheduling_algo_test.go +++ b/internal/scheduler/scheduling_algo_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "math" "testing" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/scheduler/database" "github.com/armadaproject/armada/internal/scheduler/jobdb" @@ -330,9 +330,8 @@ func TestSchedule(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx := testfixtures.ContextWithDefaultLogger(context.Background()) timeout := 5 * time.Second - ctx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), timeout) defer cancel() ctrl := gomock.NewController(t) diff --git a/internal/scheduler/simulator/simulator.go b/internal/scheduler/simulator/simulator.go index 639f617c275..94fa8989b84 100644 --- a/internal/scheduler/simulator/simulator.go +++ b/internal/scheduler/simulator/simulator.go @@ -3,20 +3,17 @@ package simulator import ( "bytes" "container/heap" - "context" - fmt "fmt" + "fmt" "os" "path/filepath" "strings" "time" "github.com/caarlos0/log" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/mattn/go-zglob" "github.com/oklog/ulid" "github.com/pkg/errors" "github.com/renstrom/shortuuid" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -25,6 +22,7 @@ import ( "k8s.io/apimachinery/pkg/util/yaml" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" commonconfig "github.com/armadaproject/armada/internal/common/config" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/util" @@ -34,7 +32,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/fairness" 
"github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/nodedb" - schedulerobjects "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/internal/scheduleringester" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -472,7 +470,7 @@ func (s *Simulator) handleScheduleEvent() error { if s.schedulingConfig.EnableNewPreemptionStrategy { sch.EnableNewPreemptionStrategy() } - ctx := ctxlogrus.ToContext(context.Background(), logrus.NewEntry(logrus.New())) + ctx := armadacontext.Background() result, err := sch.Schedule(ctx) if err != nil { return err @@ -775,7 +773,7 @@ func (s *Simulator) handleJobRunPreempted(txn *jobdb.Txn, e *armadaevents.JobRun return true, nil } -// func (a *App) TestPattern(ctx context.Context, pattern string) (*TestSuiteReport, error) { +// func (a *App) TestPattern(ctx *context.Context, pattern string) (*TestSuiteReport, error) { // testSpecs, err := TestSpecsFromPattern(pattern) // if err != nil { // return nil, err diff --git a/internal/scheduler/submitcheck.go b/internal/scheduler/submitcheck.go index 6221e2611e9..bf79e0eb317 100644 --- a/internal/scheduler/submitcheck.go +++ b/internal/scheduler/submitcheck.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "fmt" "strings" "sync" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/types" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -84,7 +84,7 @@ func NewSubmitChecker( } } -func (srv *SubmitChecker) Run(ctx context.Context) error { +func (srv *SubmitChecker) Run(ctx *armadacontext.Context) error { srv.updateExecutors(ctx) ticker := time.NewTicker(srv.ExecutorUpdateFrequency) @@ -98,7 +98,7 @@ func (srv *SubmitChecker) Run(ctx context.Context) error { } } -func (srv *SubmitChecker) updateExecutors(ctx context.Context) { +func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) { executors, err := srv.executorRepository.GetExecutors(ctx) if err != nil { log.WithError(err).Error("Error fetching executors") diff --git a/internal/scheduler/submitcheck_test.go b/internal/scheduler/submitcheck_test.go index a95f3d9abbf..87be5674bf8 100644 --- a/internal/scheduler/submitcheck_test.go +++ b/internal/scheduler/submitcheck_test.go @@ -1,7 +1,6 @@ package scheduler import ( - "context" "testing" "time" @@ -14,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/jobdb" schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks" @@ -72,7 +72,7 @@ func TestSubmitChecker_CheckJobDbJobs(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) @@ -170,7 +170,7 @@ func TestSubmitChecker_TestCheckApiJobs(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t 
*testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() ctrl := gomock.NewController(t) diff --git a/internal/scheduler/testfixtures/testfixtures.go b/internal/scheduler/testfixtures/testfixtures.go index 0acda7d60a6..e73d246c74a 100644 --- a/internal/scheduler/testfixtures/testfixtures.go +++ b/internal/scheduler/testfixtures/testfixtures.go @@ -2,16 +2,13 @@ package testfixtures // This file contains test fixtures to be used throughout the tests for this package. import ( - "context" "fmt" "math" "sync/atomic" "time" "github.com/google/uuid" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/oklog/ulid" - "github.com/sirupsen/logrus" "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -82,10 +79,6 @@ func Repeat[T any](v T, n int) []T { return rv } -func ContextWithDefaultLogger(ctx context.Context) context.Context { - return ctxlogrus.ToContext(ctx, logrus.NewEntry(logrus.New())) -} - func TestSchedulingConfig() configuration.SchedulingConfig { return configuration.SchedulingConfig{ ResourceScarcity: map[string]float64{"cpu": 1}, diff --git a/internal/scheduleringester/instructions.go b/internal/scheduleringester/instructions.go index 429ab2d9112..4a8bc70fd51 100644 --- a/internal/scheduleringester/instructions.go +++ b/internal/scheduleringester/instructions.go @@ -1,7 +1,6 @@ package scheduleringester import ( - "context" "time" "github.com/gogo/protobuf/proto" @@ -10,6 +9,7 @@ import ( "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/ingest/metrics" @@ -46,7 +46,7 @@ func NewInstructionConverter( } } -func (c *InstructionConverter) Convert(_ context.Context, sequencesWithIds *ingest.EventSequencesWithIds) *DbOperationsWithMessageIds { +func (c *InstructionConverter) Convert(_ *armadacontext.Context, sequencesWithIds *ingest.EventSequencesWithIds) *DbOperationsWithMessageIds { operations := make([]DbOperation, 0) for _, es := range sequencesWithIds.EventSequences { for _, op := range c.dbOperationsFromEventSequence(es) { diff --git a/internal/scheduleringester/schedulerdb.go b/internal/scheduleringester/schedulerdb.go index e1ce855504b..058f0f4778b 100644 --- a/internal/scheduleringester/schedulerdb.go +++ b/internal/scheduleringester/schedulerdb.go @@ -1,7 +1,6 @@ package scheduleringester import ( - "context" "time" "github.com/google/uuid" @@ -10,6 +9,7 @@ import ( "github.com/pkg/errors" "golang.org/x/exp/maps" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/ingest/metrics" @@ -45,14 +45,14 @@ func NewSchedulerDb( // Store persists all operations in the database. // This function retires until it either succeeds or encounters a terminal error. // This function locks the postgres table to avoid write conflicts; see acquireLock() for details. 
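SchedulerDb.Store above hands the armada context straight to pgx, which only expects a plain context.Context; no unwrapping is needed because armadacontext.Context embeds one. The sketch below illustrates that shape with a per-operation timeout. The helper and table name are hypothetical, and the pgx major version in the import path should match whatever the repository actually pins.

```go
package scheduleringester

import (
	"time"

	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/armadaproject/armada/internal/common/armadacontext"
)

// deleteOldMarkers is illustrative only; it mirrors the Store/acquireLock pattern above
// by deriving a bounded context and passing it directly to a pgx call.
func deleteOldMarkers(ctx *armadacontext.Context, db *pgxpool.Pool) error {
	opCtx, cancel := armadacontext.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	// Hypothetical table; the point is that opCtx satisfies context.Context.
	_, err := db.Exec(opCtx, "DELETE FROM markers WHERE created < now() - interval '1 day'")
	return err
}
```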
-func (s *SchedulerDb) Store(ctx context.Context, instructions *DbOperationsWithMessageIds) error { +func (s *SchedulerDb) Store(ctx *armadacontext.Context, instructions *DbOperationsWithMessageIds) error { return ingest.WithRetry(func() (bool, error) { err := pgx.BeginTxFunc(ctx, s.db, pgx.TxOptions{ IsoLevel: pgx.ReadCommitted, AccessMode: pgx.ReadWrite, DeferrableMode: pgx.Deferrable, }, func(tx pgx.Tx) error { - lockCtx, cancel := context.WithTimeout(ctx, s.lockTimeout) + lockCtx, cancel := armadacontext.WithTimeout(ctx, s.lockTimeout) defer cancel() // The lock is released automatically on transaction rollback/commit. if err := s.acquireLock(lockCtx, tx); err != nil { @@ -78,7 +78,7 @@ func (s *SchedulerDb) Store(ctx context.Context, instructions *DbOperationsWithM // rows with sequence numbers smaller than those already written. // // The scheduler relies on these sequence numbers to only fetch new or updated rows in each update cycle. -func (s *SchedulerDb) acquireLock(ctx context.Context, tx pgx.Tx) error { +func (s *SchedulerDb) acquireLock(ctx *armadacontext.Context, tx pgx.Tx) error { const lockId = 8741339439634283896 if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", lockId); err != nil { return errors.Wrapf(err, "could not obtain lock") @@ -86,7 +86,7 @@ func (s *SchedulerDb) acquireLock(ctx context.Context, tx pgx.Tx) error { return nil } -func (s *SchedulerDb) WriteDbOp(ctx context.Context, tx pgx.Tx, op DbOperation) error { +func (s *SchedulerDb) WriteDbOp(ctx *armadacontext.Context, tx pgx.Tx, op DbOperation) error { queries := schedulerdb.New(tx) switch o := op.(type) { case InsertJobs: @@ -274,7 +274,7 @@ func (s *SchedulerDb) WriteDbOp(ctx context.Context, tx pgx.Tx, op DbOperation) return nil } -func execBatch(ctx context.Context, tx pgx.Tx, batch *pgx.Batch) error { +func execBatch(ctx *armadacontext.Context, tx pgx.Tx, batch *pgx.Batch) error { result := tx.SendBatch(ctx, batch) for i := 0; i < batch.Len(); i++ { _, err := result.Exec() diff --git a/internal/scheduleringester/schedulerdb_test.go b/internal/scheduleringester/schedulerdb_test.go index 8317e421aff..873885c369e 100644 --- a/internal/scheduleringester/schedulerdb_test.go +++ b/internal/scheduleringester/schedulerdb_test.go @@ -1,7 +1,6 @@ package scheduleringester import ( - "context" "testing" "time" @@ -14,6 +13,7 @@ import ( "golang.org/x/exp/constraints" "golang.org/x/exp/maps" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/ingest/metrics" "github.com/armadaproject/armada/internal/common/util" schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" @@ -312,7 +312,7 @@ func addDefaultValues(op DbOperation) DbOperation { } func assertOpSuccess(t *testing.T, schedulerDb *SchedulerDb, serials map[string]int64, op DbOperation) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 10*time.Second) defer cancel() // Apply the op to the database. @@ -329,7 +329,7 @@ func assertOpSuccess(t *testing.T, schedulerDb *SchedulerDb, serials map[string] // Read back the state from the db to compare. 
queries := schedulerdb.New(schedulerDb.db) - selectNewJobs := func(ctx context.Context, serial int64) ([]schedulerdb.Job, error) { + selectNewJobs := func(ctx *armadacontext.Context, serial int64) ([]schedulerdb.Job, error) { return queries.SelectNewJobs(ctx, schedulerdb.SelectNewJobsParams{Serial: serial, Limit: 1000}) } switch expected := op.(type) { @@ -645,7 +645,7 @@ func TestStore(t *testing.T) { runId: &JobRunDetails{queue: testQueueName, dbRun: &schedulerdb.Run{JobID: jobId, RunID: runId}}, }, } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) defer cancel() err := schedulerdb.WithTestDb(func(q *schedulerdb.Queries, db *pgxpool.Pool) error { schedulerDb := NewSchedulerDb(db, metrics.NewMetrics("test"), time.Second, time.Second, 10*time.Second) From 2457da7ee9dc78a2a4dc91235a191cc8c62d557d Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Mon, 11 Sep 2023 10:26:54 -0600 Subject: [PATCH 3/9] Bump armada airflow operator to version 0.5.4 (#2961) * Bump armada airflow operator to version 0.5.4 Signed-off-by: Rich Scott * Regenerate Airflow Operator Markdown doc. Signed-off-by: Rich Scott * Fix regenerated Airflow doc error. Signed-off-by: Rich Scott * Pin versions of all modules, especially around docs generation. Signed-off-by: Rich Scott * Regenerate Airflow docs using Python 3.10 Signed-off-by: Rich Scott --------- Signed-off-by: Rich Scott --- docs/python_airflow_operator.md | 2 +- third_party/airflow/pyproject.toml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 1d820856344..c74a464751d 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -281,7 +281,7 @@ Runs the trigger. Meant to be called by an airflow triggerer process. #### serialize() -Returns the information needed to reconstruct this Trigger. +Return the information needed to reconstruct this Trigger. * **Returns** diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index bd9814cc10c..df506fc400c 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "armada_airflow" -version = "0.5.3" +version = "0.5.4" description = "Armada Airflow Operator" requires-python = ">=3.7" # Note(JayF): This dependency value is not suitable for release. Whatever @@ -9,10 +9,10 @@ requires-python = ">=3.7" # extremely difficult. 
dependencies = [ "armada-client", - "apache-airflow>=2.6.3", - "grpcio>=1.46.3", - "grpcio-tools>=1.46.3", - "types-protobuf>=3.19.22" + "apache-airflow==2.7.1", + "grpcio==1.58.0", + "grpcio-tools==1.58.0", + "types-protobuf==4.24.0.1" ] authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] license = { text = "Apache Software License" } @@ -20,7 +20,7 @@ readme = "README.md" [project.optional-dependencies] format = ["black==23.7.0", "flake8==6.1.0", "pylint==2.17.5"] -test = ["pytest==7.3.1", "coverage>=6.5.0", "pytest-asyncio==0.21.1"] +test = ["pytest==7.3.1", "coverage==7.3.1", "pytest-asyncio==0.21.1"] # note(JayF): sphinx-jekyll-builder was broken by sphinx-markdown-builder 0.6 -- so pin to 0.5.5 docs = ["sphinx==7.1.2", "sphinx-jekyll-builder==0.3.0", "sphinx-toolbox==3.2.0b1", "sphinx-markdown-builder==0.5.5"] From 2c9e4971d9fe26b16487f4b7bfea3652a7587e12 Mon Sep 17 00:00:00 2001 From: Mohamed Abdelfatah <39927413+Mo-Fatah@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:41:07 +0300 Subject: [PATCH 4/9] Magefile: Clean all Makefile refernces (#2957) * tiny naming change * clean all make refernces Signed-off-by: mohamed --------- Signed-off-by: mohamed --- .github/workflows/test.yml | 2 +- client/python/CONTRIBUTING.md | 2 +- client/python/README.md | 2 +- client/python/docs/README.md | 4 ++-- docs/developer/manual-localdev.md | 2 +- magefiles/linting.go | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4b0ea22a381..887c3658836 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -214,7 +214,7 @@ jobs: echo -e "### Git status" >> $GITHUB_STEP_SUMMARY if [[ "$changed" -gt 0 ]]; then - echo -e "Generated proto files are out of date. Please run 'make proto' and commit the changes." >> $GITHUB_STEP_SUMMARY + echo -e "Generated proto files are out of date. Please run 'mage proto' and commit the changes." >> $GITHUB_STEP_SUMMARY git status -s -uno >> $GITHUB_STEP_SUMMARY diff --git a/client/python/CONTRIBUTING.md b/client/python/CONTRIBUTING.md index ca3c0f1f90d..ff015d4284e 100644 --- a/client/python/CONTRIBUTING.md +++ b/client/python/CONTRIBUTING.md @@ -26,7 +26,7 @@ workflow for contributing. First time contributors can follow the guide below to Unlike most python projects, the Armada python client contains a large quantity of generated code. This code must be generated in order to compile and develop against the client. -From the root of the repository, run `make python`. This will generate python code needed to build +From the root of the repository, run `mage buildPython`. This will generate python code needed to build and use the client. This command needs to be re-run anytime an API change is committed (e.g. a change to a `*.proto` file). diff --git a/client/python/README.md b/client/python/README.md index 92ed96b26b8..ea4f1409fb2 100644 --- a/client/python/README.md +++ b/client/python/README.md @@ -26,5 +26,5 @@ Before beginning, ensure you have: - Network access to fetch docker images and go dependencies. To generate all needed code, and install the python client: -1) From the root of the repository, run `make python` +1) From the root of the repository, run `mage buildPython` 2) Install the client using `pip install client/python`. It's strongly recommended you do this inside a virtualenv. 
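The documentation edits in this commit swap make invocations for Mage targets. For readers unfamiliar with Mage, a target is just an exported Go function in the repository's magefiles, roughly as sketched below; the body here is hypothetical and is not the repository's actual buildPython target.

```go
//go:build mage

package main

import "github.com/magefile/mage/sh"

// BuildPython is an illustrative Mage target of the kind the updated docs invoke with
// "mage buildPython"; the real target also generates the python protobuf code first.
func BuildPython() error {
	// Hypothetical body: install the client package with pip.
	return sh.RunV("pip", "install", "./client/python")
}
```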
diff --git a/client/python/docs/README.md b/client/python/docs/README.md index 056327c87ae..d8a7abfe1a0 100644 --- a/client/python/docs/README.md +++ b/client/python/docs/README.md @@ -9,13 +9,13 @@ Usage Easy way: - Ensure all protobufs files needed for the client are generated by running - `make python` from the repository root. + `mage buildPython` from the repository root. - `tox -e docs` will create a valid virtual environment and use it to generate documentation. The generated files will be placed under `build/jekyll/*.md`. Manual way: - Ensure all protobufs files needed for the client are generated by running - `make python` from the repository root. + `mage buildPython` from the repository root. - Create a virtual environment containing all the deps listed in `tox.ini` under `[testenv:docs]`. - Run `poetry install -v` from inside `client/python` to install the client diff --git a/docs/developer/manual-localdev.md b/docs/developer/manual-localdev.md index 236995857c7..65c19e2faad 100644 --- a/docs/developer/manual-localdev.md +++ b/docs/developer/manual-localdev.md @@ -28,7 +28,7 @@ mage BootstrapTools # Compile .pb.go files from .proto files # (only necessary after changing a .proto file). mage proto -make dotnet +mage dotnet # Build the Docker images containing all Armada components. # Only the main "bundle" is needed for quickly testing Armada. diff --git a/magefiles/linting.go b/magefiles/linting.go index bc7094cebff..e301850e93c 100644 --- a/magefiles/linting.go +++ b/magefiles/linting.go @@ -63,7 +63,7 @@ func LintFix() error { } // Linting Check -func CheckLint() error { +func LintCheck() error { mg.Deps(golangciLintCheck) cmd, err := go_TEST_CMD() if err != nil { From ba1973fad1280e82c6fac47e03c9f46ef2bd40c2 Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Tue, 12 Sep 2023 15:19:18 -0600 Subject: [PATCH 5/9] Revert to previous unpinned airflow version spec. (#2967) * Revert to previous unpinned airflow version spec. Signed-off-by: Rich Scott * Increment armada-airflow module version. Signed-off-by: Rich Scott --------- Signed-off-by: Rich Scott --- third_party/airflow/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index df506fc400c..aa8296d46cb 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "armada_airflow" -version = "0.5.4" +version = "0.5.5" description = "Armada Airflow Operator" requires-python = ">=3.7" # Note(JayF): This dependency value is not suitable for release. Whatever @@ -9,7 +9,7 @@ requires-python = ">=3.7" # extremely difficult. 
dependencies = [ "armada-client", - "apache-airflow==2.7.1", + "apache-airflow>=2.6.3", "grpcio==1.58.0", "grpcio-tools==1.58.0", "types-protobuf==4.24.0.1" From edc142610011b432849a05ad8077c879aee4789f Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Wed, 13 Sep 2023 17:38:54 +0100 Subject: [PATCH 6/9] ArmadaContext.Log Improvements (#2965) * log error * context log * context log * add cycle id * typo * lint * refactor armadacontext to implement a FieldLogger --------- Co-authored-by: Chris Martin --- cmd/scheduler/cmd/prune_database.go | 4 +- internal/armada/server/lease.go | 12 ++--- internal/armada/server/submit_from_log.go | 8 +-- .../common/armadacontext/armada_context.go | 38 +++++++------- .../armadacontext/armada_context_test.go | 8 +-- internal/common/logging/stacktrace.go | 6 +-- internal/scheduler/api.go | 4 +- internal/scheduler/database/db_pruner.go | 15 +++--- internal/scheduler/database/util.go | 3 +- internal/scheduler/leader.go | 8 +-- internal/scheduler/metrics.go | 18 ++++--- .../scheduler/preempting_queue_scheduler.go | 14 +++--- internal/scheduler/publisher.go | 10 ++-- internal/scheduler/scheduler.go | 50 ++++++++++--------- internal/scheduler/scheduler_metrics.go | 34 ++++++------- internal/scheduler/schedulerapp.go | 32 ++++++------ internal/scheduler/scheduling_algo.go | 16 +++--- internal/scheduler/submitcheck.go | 14 ++++-- 18 files changed, 156 insertions(+), 138 deletions(-) diff --git a/cmd/scheduler/cmd/prune_database.go b/cmd/scheduler/cmd/prune_database.go index 3b2250d1661..4ed7aee426e 100644 --- a/cmd/scheduler/cmd/prune_database.go +++ b/cmd/scheduler/cmd/prune_database.go @@ -1,13 +1,13 @@ package cmd import ( - "context" "time" "github.com/pkg/errors" "github.com/spf13/cobra" "k8s.io/apimachinery/pkg/util/clock" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" ) @@ -57,7 +57,7 @@ func pruneDatabase(cmd *cobra.Command, _ []string) error { return errors.WithMessagef(err, "Failed to connect to database") } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), timeout) defer cancel() return schedulerdb.PruneDb(ctx, db, batchSize, expireAfter, clock.RealClock{}) } diff --git a/internal/armada/server/lease.go b/internal/armada/server/lease.go index 9a776d0e15f..2b9d7bb753e 100644 --- a/internal/armada/server/lease.go +++ b/internal/armada/server/lease.go @@ -344,7 +344,7 @@ func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.Str lastSeen, ) if err != nil { - logging.WithStacktrace(ctx.Log, err).Warnf( + logging.WithStacktrace(ctx, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetClusterId(), ) continue @@ -566,7 +566,7 @@ func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.Str if q.SchedulingContextRepository != nil { sctx.ClearJobSpecs() if err := q.SchedulingContextRepository.AddSchedulingContext(sctx); err != nil { - logging.WithStacktrace(ctx.Log, err).Error("failed to store scheduling context") + logging.WithStacktrace(ctx, err).Error("failed to store scheduling context") } } @@ -641,7 +641,7 @@ func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.Str jobIdsToDelete := util.Map(jobsToDelete, func(job *api.Job) string { return job.Id }) log.Infof("deleting preempted jobs: %v", jobIdsToDelete) if deletionResult, err := 
q.jobRepository.DeleteJobs(jobsToDelete); err != nil { - logging.WithStacktrace(ctx.Log, err).Error("failed to delete preempted jobs from Redis") + logging.WithStacktrace(ctx, err).Error("failed to delete preempted jobs from Redis") } else { deleteErrorByJobId := armadamaps.MapKeys(deletionResult, func(job *api.Job) string { return job.Id }) for jobId := range preemptedApiJobsById { @@ -704,7 +704,7 @@ func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.Str } } if err := q.usageRepository.UpdateClusterQueueResourceUsage(req.ClusterId, currentExecutorReport); err != nil { - logging.WithStacktrace(ctx.Log, err).Errorf("failed to update cluster usage") + logging.WithStacktrace(ctx, err).Errorf("failed to update cluster usage") } allocatedByQueueAndPriorityClassForPool = q.aggregateAllocationAcrossExecutor(reportsByExecutor, req.Pool) @@ -728,7 +728,7 @@ func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.Str } node, err := nodeDb.GetNode(nodeId) if err != nil { - logging.WithStacktrace(ctx.Log, err).Warnf("failed to set node id selector on job %s: node with id %s not found", apiJob.Id, nodeId) + logging.WithStacktrace(ctx, err).Warnf("failed to set node id selector on job %s: node with id %s not found", apiJob.Id, nodeId) continue } v := node.Labels[q.schedulingConfig.Preemption.NodeIdLabel] @@ -764,7 +764,7 @@ func (q *AggregatedQueueServer) getJobs(ctx *armadacontext.Context, req *api.Str } node, err := nodeDb.GetNode(nodeId) if err != nil { - logging.WithStacktrace(ctx.Log, err).Warnf("failed to set node name on job %s: node with id %s not found", apiJob.Id, nodeId) + logging.WithStacktrace(ctx, err).Warnf("failed to set node name on job %s: node with id %s not found", apiJob.Id, nodeId) continue } podSpec.NodeName = node.Name diff --git a/internal/armada/server/submit_from_log.go b/internal/armada/server/submit_from_log.go index 90b5ece3553..995e9785d5b 100644 --- a/internal/armada/server/submit_from_log.go +++ b/internal/armada/server/submit_from_log.go @@ -125,12 +125,12 @@ func (srv *SubmitFromLog) Run(ctx *armadacontext.Context) error { sequence, err := eventutil.UnmarshalEventSequence(ctxWithLogger, msg.Payload()) if err != nil { srv.ack(ctx, msg) - logging.WithStacktrace(ctxWithLogger.Log, err).Warnf("processing message failed; ignoring") + logging.WithStacktrace(ctxWithLogger, err).Warnf("processing message failed; ignoring") numErrored++ break } - ctxWithLogger.Log.WithField("numEvents", len(sequence.Events)).Info("processing sequence") + ctxWithLogger.WithField("numEvents", len(sequence.Events)).Info("processing sequence") // TODO: Improve retry logic. srv.ProcessSequence(ctxWithLogger, sequence) srv.ack(ctx, msg) @@ -155,11 +155,11 @@ func (srv *SubmitFromLog) ProcessSequence(ctx *armadacontext.Context, sequence * for i < len(sequence.Events) && time.Since(lastProgress) < timeout { j, err := srv.ProcessSubSequence(ctx, i, sequence) if err != nil { - logging.WithStacktrace(ctx.Log, err).WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Warnf("processing subsequence failed; ignoring") + logging.WithStacktrace(ctx, err).WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Warnf("processing subsequence failed; ignoring") } if j == i { - ctx.Log.WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Info("made no progress") + ctx.WithFields(logrus.Fields{"lowerIndex": i, "upperIndex": j}).Info("made no progress") // We should only get here if a transient error occurs. // Sleep for a bit before retrying. 
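After this commit the context itself satisfies logrus.FieldLogger, so call sites log through the context directly instead of going via ctx.Log, as the lease.go and submit_from_log.go hunks above show. An illustrative sketch of the resulting call style (the field values are arbitrary):

```go
package main

import (
	"github.com/armadaproject/armada/internal/common/armadacontext"
)

func main() {
	ctx := armadacontext.WithLogField(armadacontext.Background(), "service", "example")

	// The embedded FieldLogger is promoted onto the context, so these calls replace
	// the previous ctx.Log.WithField(...) and ctx.Log.Infof(...) forms.
	ctx.WithField("numEvents", 3).Info("processing sequence")
	ctx.Infof("sequence processed in %s", "12ms")
}
```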
diff --git a/internal/common/armadacontext/armada_context.go b/internal/common/armadacontext/armada_context.go index a6985ee5df7..0e41a66a1e4 100644 --- a/internal/common/armadacontext/armada_context.go +++ b/internal/common/armadacontext/armada_context.go @@ -13,22 +13,22 @@ import ( // while retaining type-safety type Context struct { context.Context - Log *logrus.Entry + logrus.FieldLogger } // Background creates an empty context with a default logger. It is analogous to context.Background() func Background() *Context { return &Context{ - Context: context.Background(), - Log: logrus.NewEntry(logrus.New()), + Context: context.Background(), + FieldLogger: logrus.NewEntry(logrus.New()), } } // TODO creates an empty context with a default logger. It is analogous to context.TODO() func TODO() *Context { return &Context{ - Context: context.TODO(), - Log: logrus.NewEntry(logrus.New()), + Context: context.TODO(), + FieldLogger: logrus.NewEntry(logrus.New()), } } @@ -42,8 +42,8 @@ func FromGrpcCtx(ctx context.Context) *Context { // New returns an armada context that encapsulates both a go context and a logger func New(ctx context.Context, log *logrus.Entry) *Context { return &Context{ - Context: ctx, - Log: log, + Context: ctx, + FieldLogger: log, } } @@ -51,8 +51,8 @@ func New(ctx context.Context, log *logrus.Entry) *Context { func WithCancel(parent *Context) (*Context, context.CancelFunc) { c, cancel := context.WithCancel(parent.Context) return &Context{ - Context: c, - Log: parent.Log, + Context: c, + FieldLogger: parent.FieldLogger, }, cancel } @@ -61,8 +61,8 @@ func WithCancel(parent *Context) (*Context, context.CancelFunc) { func WithDeadline(parent *Context, d time.Time) (*Context, context.CancelFunc) { c, cancel := context.WithDeadline(parent.Context, d) return &Context{ - Context: c, - Log: parent.Log, + Context: c, + FieldLogger: parent.FieldLogger, }, cancel } @@ -74,16 +74,16 @@ func WithTimeout(parent *Context, timeout time.Duration) (*Context, context.Canc // WithLogField returns a copy of parent with the supplied key-value added to the logger func WithLogField(parent *Context, key string, val interface{}) *Context { return &Context{ - Context: parent.Context, - Log: parent.Log.WithField(key, val), + Context: parent.Context, + FieldLogger: parent.FieldLogger.WithField(key, val), } } // WithLogFields returns a copy of parent with the supplied key-values added to the logger func WithLogFields(parent *Context, fields logrus.Fields) *Context { return &Context{ - Context: parent.Context, - Log: parent.Log.WithFields(fields), + Context: parent.Context, + FieldLogger: parent.FieldLogger.WithFields(fields), } } @@ -91,8 +91,8 @@ func WithLogFields(parent *Context, fields logrus.Fields) *Context { // val. 
It is analogous to context.WithValue() func WithValue(parent *Context, key, val any) *Context { return &Context{ - Context: context.WithValue(parent, key, val), - Log: parent.Log, + Context: context.WithValue(parent, key, val), + FieldLogger: parent.FieldLogger, } } @@ -101,7 +101,7 @@ func WithValue(parent *Context, key, val any) *Context { func ErrGroup(ctx *Context) (*errgroup.Group, *Context) { group, goctx := errgroup.WithContext(ctx) return group, &Context{ - Context: goctx, - Log: ctx.Log, + Context: goctx, + FieldLogger: ctx.FieldLogger, } } diff --git a/internal/common/armadacontext/armada_context_test.go b/internal/common/armadacontext/armada_context_test.go index a98d7b611df..4cda401c1b1 100644 --- a/internal/common/armadacontext/armada_context_test.go +++ b/internal/common/armadacontext/armada_context_test.go @@ -15,7 +15,7 @@ var defaultLogger = logrus.WithField("foo", "bar") func TestNew(t *testing.T) { ctx := New(context.Background(), defaultLogger) - require.Equal(t, defaultLogger, ctx.Log) + require.Equal(t, defaultLogger, ctx.FieldLogger) require.Equal(t, context.Background(), ctx.Context) } @@ -23,7 +23,7 @@ func TestFromGrpcContext(t *testing.T) { grpcCtx := ctxlogrus.ToContext(context.Background(), defaultLogger) ctx := FromGrpcCtx(grpcCtx) require.Equal(t, grpcCtx, ctx.Context) - require.Equal(t, defaultLogger, ctx.Log) + require.Equal(t, defaultLogger, ctx.FieldLogger) } func TestBackground(t *testing.T) { @@ -39,13 +39,13 @@ func TestTODO(t *testing.T) { func TestWithLogField(t *testing.T) { ctx := WithLogField(Background(), "fish", "chips") require.Equal(t, context.Background(), ctx.Context) - require.Equal(t, logrus.Fields{"fish": "chips"}, ctx.Log.Data) + require.Equal(t, logrus.Fields{"fish": "chips"}, ctx.FieldLogger.(*logrus.Entry).Data) } func TestWithLogFields(t *testing.T) { ctx := WithLogFields(Background(), logrus.Fields{"fish": "chips", "salt": "pepper"}) require.Equal(t, context.Background(), ctx.Context) - require.Equal(t, logrus.Fields{"fish": "chips", "salt": "pepper"}, ctx.Log.Data) + require.Equal(t, logrus.Fields{"fish": "chips", "salt": "pepper"}, ctx.FieldLogger.(*logrus.Entry).Data) } func TestWithTimeout(t *testing.T) { diff --git a/internal/common/logging/stacktrace.go b/internal/common/logging/stacktrace.go index cdcf4aef525..7d546915b31 100644 --- a/internal/common/logging/stacktrace.go +++ b/internal/common/logging/stacktrace.go @@ -10,9 +10,9 @@ type stackTracer interface { StackTrace() errors.StackTrace } -// WithStacktrace returns a new logrus.Entry obtained by adding error information and, if available, a stack trace -// as fields to the provided logrus.Entry. -func WithStacktrace(logger *logrus.Entry, err error) *logrus.Entry { +// WithStacktrace returns a new logrus.FieldLogger obtained by adding error information and, if available, a stack trace +// as fields to the provided logrus.FieldLogger. 
+func WithStacktrace(logger logrus.FieldLogger, err error) logrus.FieldLogger { logger = logger.WithError(err) if stackErr, ok := err.(stackTracer); ok { return logger.WithField("stacktrace", stackErr.StackTrace()) diff --git a/internal/scheduler/api.go b/internal/scheduler/api.go index a31eba85f5e..533abc4b728 100644 --- a/internal/scheduler/api.go +++ b/internal/scheduler/api.go @@ -103,7 +103,7 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns if err != nil { return err } - ctx.Log.Infof( + ctx.Infof( "executor currently has %d job runs; sending %d cancellations and %d new runs", len(requestRuns), len(runsToCancel), len(newRuns), ) @@ -226,7 +226,7 @@ func (srv *ExecutorApi) executorFromLeaseRequest(ctx *armadacontext.Context, req now := srv.clock.Now().UTC() for _, nodeInfo := range req.Nodes { if node, err := api.NewNodeFromNodeInfo(nodeInfo, req.ExecutorId, srv.allowedPriorities, now); err != nil { - logging.WithStacktrace(ctx.Log, err).Warnf( + logging.WithStacktrace(ctx, err).Warnf( "skipping node %s from executor %s", nodeInfo.GetName(), req.GetExecutorId(), ) } else { diff --git a/internal/scheduler/database/db_pruner.go b/internal/scheduler/database/db_pruner.go index 9ea8075a40d..8da7dd7935d 100644 --- a/internal/scheduler/database/db_pruner.go +++ b/internal/scheduler/database/db_pruner.go @@ -1,13 +1,13 @@ package database import ( - ctx "context" "time" "github.com/jackc/pgx/v5" "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" + + "github.com/armadaproject/armada/internal/common/armadacontext" ) // PruneDb removes completed jobs (and related runs and errors) from the database if their `lastUpdateTime` @@ -15,7 +15,7 @@ import ( // Jobs are deleted in batches across transactions. This means that if this job fails midway through, it still // may have deleted some jobs. // The function will run until the supplied context is cancelled. -func PruneDb(ctx ctx.Context, db *pgx.Conn, batchLimit int, keepAfterCompletion time.Duration, clock clock.Clock) error { +func PruneDb(ctx *armadacontext.Context, db *pgx.Conn, batchLimit int, keepAfterCompletion time.Duration, clock clock.Clock) error { start := time.Now() cutOffTime := clock.Now().Add(-keepAfterCompletion) @@ -40,11 +40,11 @@ func PruneDb(ctx ctx.Context, db *pgx.Conn, batchLimit int, keepAfterCompletion return errors.WithStack(err) } if totalJobsToDelete == 0 { - log.Infof("Found no jobs to be deleted. Exiting") + ctx.Infof("Found no jobs to be deleted. Exiting") return nil } - log.Infof("Found %d jobs to be deleted", totalJobsToDelete) + ctx.Infof("Found %d jobs to be deleted", totalJobsToDelete) // create temp table to hold a batch of results _, err = db.Exec(ctx, "CREATE TEMP TABLE batch (job_id TEXT);") @@ -93,9 +93,10 @@ func PruneDb(ctx ctx.Context, db *pgx.Conn, batchLimit int, keepAfterCompletion taken := time.Now().Sub(batchStart) jobsDeleted += batchSize - log.Infof("Deleted %d jobs in %s. Deleted %d jobs out of %d", batchSize, taken, jobsDeleted, totalJobsToDelete) + ctx. + Infof("Deleted %d jobs in %s. 
Deleted %d jobs out of %d", batchSize, taken, jobsDeleted, totalJobsToDelete) } taken := time.Now().Sub(start) - log.Infof("Deleted %d jobs in %s", jobsDeleted, taken) + ctx.Infof("Deleted %d jobs in %s", jobsDeleted, taken) return nil } diff --git a/internal/scheduler/database/util.go b/internal/scheduler/database/util.go index 618c32c8efb..af338ee3b42 100644 --- a/internal/scheduler/database/util.go +++ b/internal/scheduler/database/util.go @@ -6,7 +6,6 @@ import ( "time" "github.com/jackc/pgx/v5/pgxpool" - log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" @@ -25,7 +24,7 @@ func Migrate(ctx *armadacontext.Context, db database.Querier) error { if err != nil { return err } - log.Infof("Updated scheduler database in %s", time.Now().Sub(start)) + ctx.Infof("Updated scheduler database in %s", time.Now().Sub(start)) return nil } diff --git a/internal/scheduler/leader.go b/internal/scheduler/leader.go index 0482184a7a8..714cf243f52 100644 --- a/internal/scheduler/leader.go +++ b/internal/scheduler/leader.go @@ -145,7 +145,7 @@ func (lc *KubernetesLeaderController) Run(ctx *armadacontext.Context) error { return ctx.Err() default: lock := lc.getNewLock() - ctx.Log.Infof("attempting to become leader") + ctx.Infof("attempting to become leader") leaderelection.RunOrDie(ctx, leaderelection.LeaderElectionConfig{ Lock: lock, ReleaseOnCancel: true, @@ -154,14 +154,14 @@ func (lc *KubernetesLeaderController) Run(ctx *armadacontext.Context) error { RetryPeriod: lc.config.RetryPeriod, Callbacks: leaderelection.LeaderCallbacks{ OnStartedLeading: func(c context.Context) { - ctx.Log.Infof("I am now leader") + ctx.Infof("I am now leader") lc.token.Store(NewLeaderToken()) for _, listener := range lc.listeners { listener.onStartedLeading(ctx) } }, OnStoppedLeading: func() { - ctx.Log.Infof("I am no longer leader") + ctx.Infof("I am no longer leader") lc.token.Store(InvalidLeaderToken()) for _, listener := range lc.listeners { listener.onStoppedLeading() @@ -174,7 +174,7 @@ func (lc *KubernetesLeaderController) Run(ctx *armadacontext.Context) error { }, }, }) - ctx.Log.Infof("leader election round finished") + ctx.Infof("leader election round finished") } } } diff --git a/internal/scheduler/metrics.go b/internal/scheduler/metrics.go index a7fb2f08c78..168295ff91f 100644 --- a/internal/scheduler/metrics.go +++ b/internal/scheduler/metrics.go @@ -7,10 +7,10 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/logging" commonmetrics "github.com/armadaproject/armada/internal/common/metrics" "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/scheduler/database" @@ -78,16 +78,18 @@ func NewMetricsCollector( // Run enters s a loop which updates the metrics every refreshPeriod until the supplied context is cancelled func (c *MetricsCollector) Run(ctx *armadacontext.Context) error { ticker := c.clock.NewTicker(c.refreshPeriod) - log.Infof("Will update metrics every %s", c.refreshPeriod) + ctx.Infof("Will update metrics every %s", c.refreshPeriod) for { select { case <-ctx.Done(): - log.Debugf("Context cancelled, returning..") + ctx.Debugf("Context cancelled, returning..") return nil case <-ticker.C(): err := c.refresh(ctx) if err != nil { - 
log.WithError(err).Warnf("error refreshing metrics state") + logging. + WithStacktrace(ctx, err). + Warnf("error refreshing metrics state") } } } @@ -109,7 +111,7 @@ func (c *MetricsCollector) Collect(metrics chan<- prometheus.Metric) { } func (c *MetricsCollector) refresh(ctx *armadacontext.Context) error { - log.Debugf("Refreshing prometheus metrics") + ctx.Debugf("Refreshing prometheus metrics") start := time.Now() queueMetrics, err := c.updateQueueMetrics(ctx) if err != nil { @@ -121,7 +123,7 @@ func (c *MetricsCollector) refresh(ctx *armadacontext.Context) error { } allMetrics := append(queueMetrics, clusterMetrics...) c.state.Store(allMetrics) - log.Debugf("Refreshed prometheus metrics in %s", time.Since(start)) + ctx.Debugf("Refreshed prometheus metrics in %s", time.Since(start)) return nil } @@ -154,7 +156,7 @@ func (c *MetricsCollector) updateQueueMetrics(ctx *armadacontext.Context) ([]pro } qs, ok := provider.queueStates[job.Queue()] if !ok { - log.Warnf("job %s is in queue %s, but this queue does not exist; skipping", job.Id(), job.Queue()) + ctx.Warnf("job %s is in queue %s, but this queue does not exist; skipping", job.Id(), job.Queue()) continue } @@ -181,7 +183,7 @@ func (c *MetricsCollector) updateQueueMetrics(ctx *armadacontext.Context) ([]pro timeInState = currentTime.Sub(time.Unix(0, run.Created())) recorder = qs.runningJobRecorder } else { - log.Warnf("Job %s is marked as leased but has no runs", job.Id()) + ctx.Warnf("Job %s is marked as leased but has no runs", job.Id()) } recorder.RecordJobRuntime(pool, priorityClass, timeInState) recorder.RecordResources(pool, priorityClass, jobResources) diff --git a/internal/scheduler/preempting_queue_scheduler.go b/internal/scheduler/preempting_queue_scheduler.go index fd0c0d9e079..adfbe4d86b5 100644 --- a/internal/scheduler/preempting_queue_scheduler.go +++ b/internal/scheduler/preempting_queue_scheduler.go @@ -129,11 +129,11 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*Sche sch.nodeEvictionProbability, func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { if job.GetAnnotations() == nil { - ctx.Log.Errorf("can't evict job %s: annotations not initialised", job.GetId()) + ctx.Errorf("can't evict job %s: annotations not initialised", job.GetId()) return false } if job.GetNodeSelector() == nil { - ctx.Log.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) + ctx.Errorf("can't evict job %s: nodeSelector not initialised", job.GetId()) return false } if qctx, ok := sch.schedulingContext.QueueSchedulingContexts[job.GetQueue()]; ok { @@ -241,10 +241,10 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*Sche return nil, err } if s := JobsSummary(preemptedJobs); s != "" { - ctx.Log.Infof("preempting running jobs; %s", s) + ctx.Infof("preempting running jobs; %s", s) } if s := JobsSummary(scheduledJobs); s != "" { - ctx.Log.Infof("scheduling new jobs; %s", s) + ctx.Infof("scheduling new jobs; %s", s) } if sch.enableAssertions { err := sch.assertions( @@ -805,7 +805,7 @@ func NewOversubscribedEvictor( }, jobFilter: func(ctx *armadacontext.Context, job interfaces.LegacySchedulerJob) bool { if job.GetAnnotations() == nil { - ctx.Log.Warnf("can't evict job %s: annotations not initialised", job.GetId()) + ctx.Warnf("can't evict job %s: annotations not initialised", job.GetId()) return false } priorityClassName := job.GetPriorityClassName() @@ -884,7 +884,7 @@ func defaultPostEvictFunc(ctx *armadacontext.Context, job interfaces.LegacySched // 
Add annotation indicating to the scheduler this this job was evicted. annotations := job.GetAnnotations() if annotations == nil { - ctx.Log.Errorf("error evicting job %s: annotations not initialised", job.GetId()) + ctx.Errorf("error evicting job %s: annotations not initialised", job.GetId()) } else { annotations[schedulerconfig.IsEvictedAnnotation] = "true" } @@ -892,7 +892,7 @@ func defaultPostEvictFunc(ctx *armadacontext.Context, job interfaces.LegacySched // Add node selector ensuring this job is only re-scheduled onto the node it was evicted from. nodeSelector := job.GetNodeSelector() if nodeSelector == nil { - ctx.Log.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) + ctx.Errorf("error evicting job %s: nodeSelector not initialised", job.GetId()) } else { nodeSelector[schedulerconfig.NodeIdLabel] = node.Id } diff --git a/internal/scheduler/publisher.go b/internal/scheduler/publisher.go index 0b308141961..598a00fc755 100644 --- a/internal/scheduler/publisher.go +++ b/internal/scheduler/publisher.go @@ -10,10 +10,10 @@ import ( "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/eventutil" + "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/schedulers" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -103,13 +103,15 @@ func (p *PulsarPublisher) PublishMessages(ctx *armadacontext.Context, events []* // Send messages if shouldPublish() { - log.Debugf("Am leader so will publish") + ctx.Debugf("Am leader so will publish") sendCtx, cancel := armadacontext.WithTimeout(ctx, p.pulsarSendTimeout) errored := false for _, msg := range msgs { p.producer.SendAsync(sendCtx, msg, func(_ pulsar.MessageID, _ *pulsar.ProducerMessage, err error) { if err != nil { - log.WithError(err).Error("error sending message to Pulsar") + logging. + WithStacktrace(ctx, err). + Error("error sending message to Pulsar") errored = true } wg.Done() @@ -121,7 +123,7 @@ func (p *PulsarPublisher) PublishMessages(ctx *armadacontext.Context, events []* return errors.New("One or more messages failed to send to Pulsar") } } else { - log.Debugf("No longer leader so not publishing") + ctx.Debugf("No longer leader so not publishing") } return nil } diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index ccc4d998ff5..4084ae726a8 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -7,6 +7,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/pkg/errors" + "github.com/renstrom/shortuuid" "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" @@ -116,36 +117,37 @@ func NewScheduler( // Run enters the scheduling loop, which will continue until ctx is cancelled. func (s *Scheduler) Run(ctx *armadacontext.Context) error { - ctx.Log.Infof("starting scheduler with cycle time %s", s.cyclePeriod) - defer ctx.Log.Info("scheduler stopped") + ctx.Infof("starting scheduler with cycle time %s", s.cyclePeriod) + defer ctx.Info("scheduler stopped") // JobDb initialisation. 
start := s.clock.Now() if err := s.initialise(ctx); err != nil { return err } - ctx.Log.Infof("JobDb initialised in %s", s.clock.Since(start)) + ctx.Infof("JobDb initialised in %s", s.clock.Since(start)) ticker := s.clock.NewTicker(s.cyclePeriod) prevLeaderToken := InvalidLeaderToken() for { select { case <-ctx.Done(): - ctx.Log.Infof("context cancelled; returning.") + ctx.Infof("context cancelled; returning.") return ctx.Err() case <-ticker.C(): start := s.clock.Now() + ctx := armadacontext.WithLogField(ctx, "cycleId", shortuuid.New()) leaderToken := s.leaderController.GetToken() fullUpdate := false - ctx.Log.Infof("received leaderToken; leader status is %t", leaderToken.leader) + ctx.Infof("received leaderToken; leader status is %t", leaderToken.leader) // If we are becoming leader then we must ensure we have caught up to all Pulsar messages if leaderToken.leader && leaderToken != prevLeaderToken { - ctx.Log.Infof("becoming leader") + ctx.Infof("becoming leader") syncContext, cancel := armadacontext.WithTimeout(ctx, 5*time.Minute) err := s.ensureDbUpToDate(syncContext, 1*time.Second) if err != nil { - logging.WithStacktrace(ctx.Log, err).Error("could not become leader") + logging.WithStacktrace(ctx, err).Error("could not become leader") leaderToken = InvalidLeaderToken() } else { fullUpdate = true @@ -165,7 +167,7 @@ func (s *Scheduler) Run(ctx *armadacontext.Context) error { result, err := s.cycle(ctx, fullUpdate, leaderToken, shouldSchedule) if err != nil { - logging.WithStacktrace(ctx.Log, err).Error("scheduling cycle failure") + logging.WithStacktrace(ctx, err).Error("scheduling cycle failure") leaderToken = InvalidLeaderToken() } @@ -174,13 +176,13 @@ func (s *Scheduler) Run(ctx *armadacontext.Context) error { s.metrics.ResetGaugeMetrics() if shouldSchedule && leaderToken.leader { - // Only the leader token does real scheduling rounds. + // Only the leader does real scheduling rounds. s.metrics.ReportScheduleCycleTime(cycleTime) - s.metrics.ReportSchedulerResult(result) - ctx.Log.Infof("scheduling cycle completed in %s", cycleTime) + s.metrics.ReportSchedulerResult(ctx, result) + ctx.Infof("scheduling cycle completed in %s", cycleTime) } else { s.metrics.ReportReconcileCycleTime(cycleTime) - ctx.Log.Infof("reconciliation cycle completed in %s", cycleTime) + ctx.Infof("reconciliation cycle completed in %s", cycleTime) } prevLeaderToken = leaderToken @@ -256,7 +258,7 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke if err = s.publisher.PublishMessages(ctx, events, isLeader); err != nil { return } - ctx.Log.Infof("published %d events to pulsar in %s", len(events), s.clock.Since(start)) + ctx.Infof("published %d events to pulsar in %s", len(events), s.clock.Since(start)) txn.Commit() return } @@ -268,7 +270,7 @@ func (s *Scheduler) syncState(ctx *armadacontext.Context) ([]*jobdb.Job, error) if err != nil { return nil, err } - ctx.Log.Infof("received %d updated jobs and %d updated job runs in %s", len(updatedJobs), len(updatedRuns), s.clock.Since(start)) + ctx.Infof("received %d updated jobs and %d updated job runs in %s", len(updatedJobs), len(updatedRuns), s.clock.Since(start)) txn := s.jobDb.WriteTxn() defer txn.Abort() @@ -312,7 +314,7 @@ func (s *Scheduler) syncState(ctx *armadacontext.Context) ([]*jobdb.Job, error) // If the job is nil or terminal at this point then it cannot be active. // In this case we can ignore the run. 
if job == nil || job.InTerminalState() { - ctx.Log.Debugf("job %s is not active; ignoring update for run %s", jobId, dbRun.RunID) + ctx.Debugf("job %s is not active; ignoring update for run %s", jobId, dbRun.RunID) continue } } @@ -714,14 +716,14 @@ func (s *Scheduler) expireJobsIfNecessary(ctx *armadacontext.Context, txn *jobdb // has been completely removed for executor, heartbeat := range heartbeatTimes { if heartbeat.Before(cutOff) { - ctx.Log.Warnf("Executor %s has not reported a hearbeart since %v. Will expire all jobs running on this executor", executor, heartbeat) + ctx.Warnf("Executor %s has not reported a hearbeart since %v. Will expire all jobs running on this executor", executor, heartbeat) staleExecutors[executor] = true } } // All clusters have had a heartbeat recently. No need to expire any jobs if len(staleExecutors) == 0 { - ctx.Log.Infof("No stale executors found. No jobs need to be expired") + ctx.Infof("No stale executors found. No jobs need to be expired") return nil, nil } @@ -738,7 +740,7 @@ func (s *Scheduler) expireJobsIfNecessary(ctx *armadacontext.Context, txn *jobdb run := job.LatestRun() if run != nil && !job.Queued() && staleExecutors[run.Executor()] { - ctx.Log.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) + ctx.Warnf("Cancelling job %s as it is running on lost executor %s", job.Id(), run.Executor()) jobsToUpdate = append(jobsToUpdate, job.WithQueued(false).WithFailed(true).WithUpdatedRun(run.WithFailed(true))) jobId, err := armadaevents.ProtoUuidFromUlidString(job.Id()) @@ -803,7 +805,7 @@ func (s *Scheduler) initialise(ctx *armadacontext.Context) error { return nil default: if _, err := s.syncState(ctx); err != nil { - ctx.Log.WithError(err).Error("failed to initialise; trying again in 1 second") + logging.WithStacktrace(ctx, err).Error("failed to initialise; trying again in 1 second") time.Sleep(1 * time.Second) } else { // Initialisation succeeded. @@ -830,7 +832,7 @@ func (s *Scheduler) ensureDbUpToDate(ctx *armadacontext.Context, pollInterval ti default: numSent, err = s.publisher.PublishMarkers(ctx, groupId) if err != nil { - ctx.Log.WithError(err).Error("Error sending marker messages to pulsar") + logging.WithStacktrace(ctx, err).Error("Error sending marker messages to pulsar") s.clock.Sleep(pollInterval) } else { messagesSent = true @@ -846,13 +848,15 @@ func (s *Scheduler) ensureDbUpToDate(ctx *armadacontext.Context, pollInterval ti default: numReceived, err := s.jobRepository.CountReceivedPartitions(ctx, groupId) if err != nil { - ctx.Log.WithError(err).Error("Error querying the database or marker messages") + logging. + WithStacktrace(ctx, err). 
+ Error("Error querying the database or marker messages") } if numSent == numReceived { - ctx.Log.Infof("Successfully ensured that database state is up to date") + ctx.Infof("Successfully ensured that database state is up to date") return nil } - ctx.Log.Infof("Recevied %d partitions, still waiting on %d", numReceived, numSent-numReceived) + ctx.Infof("Recevied %d partitions, still waiting on %d", numReceived, numSent-numReceived) s.clock.Sleep(pollInterval) } } diff --git a/internal/scheduler/scheduler_metrics.go b/internal/scheduler/scheduler_metrics.go index 25840fae841..3ba197ebeba 100644 --- a/internal/scheduler/scheduler_metrics.go +++ b/internal/scheduler/scheduler_metrics.go @@ -4,9 +4,9 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/armada/configuration" + "github.com/armadaproject/armada/internal/common/armadacontext" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" "github.com/armadaproject/armada/internal/scheduler/interfaces" ) @@ -157,29 +157,29 @@ func (metrics *SchedulerMetrics) ReportReconcileCycleTime(cycleTime time.Duratio metrics.reconcileCycleTime.Observe(float64(cycleTime.Milliseconds())) } -func (metrics *SchedulerMetrics) ReportSchedulerResult(result SchedulerResult) { +func (metrics *SchedulerMetrics) ReportSchedulerResult(ctx *armadacontext.Context, result SchedulerResult) { if result.EmptyResult { return // TODO: Add logging or maybe place to add failure metric? } // Report the total scheduled jobs (possibly we can get these out of contexts?) - metrics.reportScheduledJobs(result.ScheduledJobs) - metrics.reportPreemptedJobs(result.PreemptedJobs) + metrics.reportScheduledJobs(ctx, result.ScheduledJobs) + metrics.reportPreemptedJobs(ctx, result.PreemptedJobs) // TODO: When more metrics are added, consider consolidating into a single loop over the data. // Report the number of considered jobs. - metrics.reportNumberOfJobsConsidered(result.SchedulingContexts) - metrics.reportQueueShares(result.SchedulingContexts) + metrics.reportNumberOfJobsConsidered(ctx, result.SchedulingContexts) + metrics.reportQueueShares(ctx, result.SchedulingContexts) } -func (metrics *SchedulerMetrics) reportScheduledJobs(scheduledJobs []interfaces.LegacySchedulerJob) { +func (metrics *SchedulerMetrics) reportScheduledJobs(ctx *armadacontext.Context, scheduledJobs []interfaces.LegacySchedulerJob) { jobAggregates := aggregateJobs(scheduledJobs) - observeJobAggregates(metrics.scheduledJobsPerQueue, jobAggregates) + observeJobAggregates(ctx, metrics.scheduledJobsPerQueue, jobAggregates) } -func (metrics *SchedulerMetrics) reportPreemptedJobs(preemptedJobs []interfaces.LegacySchedulerJob) { +func (metrics *SchedulerMetrics) reportPreemptedJobs(ctx *armadacontext.Context, preemptedJobs []interfaces.LegacySchedulerJob) { jobAggregates := aggregateJobs(preemptedJobs) - observeJobAggregates(metrics.preemptedJobsPerQueue, jobAggregates) + observeJobAggregates(ctx, metrics.preemptedJobsPerQueue, jobAggregates) } type collectionKey struct { @@ -200,7 +200,7 @@ func aggregateJobs[S ~[]E, E interfaces.LegacySchedulerJob](scheduledJobs S) map } // observeJobAggregates reports a set of job aggregates to a given CounterVec by queue and priorityClass. 
-func observeJobAggregates(metric prometheus.CounterVec, jobAggregates map[collectionKey]int) { +func observeJobAggregates(ctx *armadacontext.Context, metric prometheus.CounterVec, jobAggregates map[collectionKey]int) { for key, count := range jobAggregates { queue := key.queue priorityClassName := key.priorityClass @@ -209,14 +209,14 @@ func observeJobAggregates(metric prometheus.CounterVec, jobAggregates map[collec if err != nil { // A metric failure isn't reason to kill the programme. - log.Errorf("error reteriving considered jobs observer for queue %s, priorityClass %s", queue, priorityClassName) + ctx.Errorf("error reteriving considered jobs observer for queue %s, priorityClass %s", queue, priorityClassName) } else { observer.Add(float64(count)) } } } -func (metrics *SchedulerMetrics) reportNumberOfJobsConsidered(schedulingContexts []*schedulercontext.SchedulingContext) { +func (metrics *SchedulerMetrics) reportNumberOfJobsConsidered(ctx *armadacontext.Context, schedulingContexts []*schedulercontext.SchedulingContext) { for _, schedContext := range schedulingContexts { pool := schedContext.Pool for queue, queueContext := range schedContext.QueueSchedulingContexts { @@ -224,7 +224,7 @@ func (metrics *SchedulerMetrics) reportNumberOfJobsConsidered(schedulingContexts observer, err := metrics.consideredJobs.GetMetricWithLabelValues(queue, pool) if err != nil { - log.Errorf("error reteriving considered jobs observer for queue %s, pool %s", queue, pool) + ctx.Errorf("error reteriving considered jobs observer for queue %s, pool %s", queue, pool) } else { observer.Add(float64(count)) } @@ -232,7 +232,7 @@ func (metrics *SchedulerMetrics) reportNumberOfJobsConsidered(schedulingContexts } } -func (metrics *SchedulerMetrics) reportQueueShares(schedulingContexts []*schedulercontext.SchedulingContext) { +func (metrics *SchedulerMetrics) reportQueueShares(ctx *armadacontext.Context, schedulingContexts []*schedulercontext.SchedulingContext) { for _, schedContext := range schedulingContexts { totalCost := schedContext.TotalCost() totalWeight := schedContext.WeightSum @@ -243,7 +243,7 @@ func (metrics *SchedulerMetrics) reportQueueShares(schedulingContexts []*schedul observer, err := metrics.fairSharePerQueue.GetMetricWithLabelValues(queue, pool) if err != nil { - log.Errorf("error reteriving considered jobs observer for queue %s, pool %s", queue, pool) + ctx.Errorf("error retrieving considered jobs observer for queue %s, pool %s", queue, pool) } else { observer.Set(fairShare) } @@ -252,7 +252,7 @@ func (metrics *SchedulerMetrics) reportQueueShares(schedulingContexts []*schedul observer, err = metrics.actualSharePerQueue.GetMetricWithLabelValues(queue, pool) if err != nil { - log.Errorf("error reteriving considered jobs observer for queue %s, pool %s", queue, pool) + ctx.Errorf("error reteriving considered jobs observer for queue %s, pool %s", queue, pool) } else { observer.Set(actualShare) } diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 9ba1302c920..c045591b175 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -12,7 +12,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" @@ -24,6 +23,7 @@ import ( dbcommon "github.com/armadaproject/armada/internal/common/database" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" 
"github.com/armadaproject/armada/internal/common/health" + "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/pulsarutils" "github.com/armadaproject/armada/internal/common/stringinterner" schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" @@ -55,7 +55,7 @@ func Run(config schedulerconfig.Configuration) error { ////////////////////////////////////////////////////////////////////////// // Database setup (postgres and redis) ////////////////////////////////////////////////////////////////////////// - log.Infof("Setting up database connections") + ctx.Infof("Setting up database connections") db, err := dbcommon.OpenPgxPool(config.Postgres) if err != nil { return errors.WithMessage(err, "Error opening connection to postgres") @@ -68,7 +68,9 @@ func Run(config schedulerconfig.Configuration) error { defer func() { err := redisClient.Close() if err != nil { - log.WithError(errors.WithStack(err)).Warnf("Redis client didn't close down cleanly") + logging. + WithStacktrace(ctx, err). + Warnf("Redis client didn't close down cleanly") } }() queueRepository := database.NewLegacyQueueRepository(redisClient) @@ -77,7 +79,7 @@ func Run(config schedulerconfig.Configuration) error { ////////////////////////////////////////////////////////////////////////// // Pulsar ////////////////////////////////////////////////////////////////////////// - log.Infof("Setting up Pulsar connectivity") + ctx.Infof("Setting up Pulsar connectivity") pulsarClient, err := pulsarutils.NewPulsarClient(&config.Pulsar) if err != nil { return errors.WithMessage(err, "Error creating pulsar client") @@ -97,7 +99,7 @@ func Run(config schedulerconfig.Configuration) error { ////////////////////////////////////////////////////////////////////////// // Leader Election ////////////////////////////////////////////////////////////////////////// - leaderController, err := createLeaderController(config.Leader) + leaderController, err := createLeaderController(ctx, config.Leader) if err != nil { return errors.WithMessage(err, "error creating leader controller") } @@ -106,7 +108,7 @@ func Run(config schedulerconfig.Configuration) error { ////////////////////////////////////////////////////////////////////////// // Executor Api ////////////////////////////////////////////////////////////////////////// - log.Infof("Setting up executor api") + ctx.Infof("Setting up executor api") apiProducer, err := pulsarClient.CreateProducer(pulsar.ProducerOptions{ Name: fmt.Sprintf("armada-executor-api-%s", uuid.NewString()), CompressionType: config.Pulsar.CompressionType, @@ -144,7 +146,7 @@ func Run(config schedulerconfig.Configuration) error { } executorapi.RegisterExecutorApiServer(grpcServer, executorServer) services = append(services, func() error { - log.Infof("Executor api listening on %s", lis.Addr()) + ctx.Infof("Executor api listening on %s", lis.Addr()) return grpcServer.Serve(lis) }) services = append(services, grpcCommon.CreateShutdownHandler(ctx, 5*time.Second, grpcServer)) @@ -152,7 +154,7 @@ func Run(config schedulerconfig.Configuration) error { ////////////////////////////////////////////////////////////////////////// // Scheduling ////////////////////////////////////////////////////////////////////////// - log.Infof("setting up scheduling loop") + ctx.Infof("setting up scheduling loop") stringInterner, err := stringinterner.New(config.InternedStringsCacheSize) if err != nil { return errors.WithMessage(err, "error creating string interner") @@ -238,14 +240,14 
@@ func Run(config schedulerconfig.Configuration) error { return g.Wait() } -func createLeaderController(config schedulerconfig.LeaderConfig) (LeaderController, error) { +func createLeaderController(ctx *armadacontext.Context, config schedulerconfig.LeaderConfig) (LeaderController, error) { switch mode := strings.ToLower(config.Mode); mode { case "standalone": - log.Infof("Scheduler will run in standalone mode") + ctx.Infof("Scheduler will run in standalone mode") return NewStandaloneLeaderController(), nil case "kubernetes": - log.Infof("Scheduler will run kubernetes mode") - clusterConfig, err := loadClusterConfig() + ctx.Infof("Scheduler will run kubernetes mode") + clusterConfig, err := loadClusterConfig(ctx) if err != nil { return nil, errors.Wrapf(err, "Error creating kubernetes client") } @@ -263,14 +265,14 @@ func createLeaderController(config schedulerconfig.LeaderConfig) (LeaderControll } } -func loadClusterConfig() (*rest.Config, error) { +func loadClusterConfig(ctx *armadacontext.Context) (*rest.Config, error) { config, err := rest.InClusterConfig() if err == rest.ErrNotInCluster { - log.Info("Running with default client configuration") + ctx.Info("Running with default client configuration") rules := clientcmd.NewDefaultClientConfigLoadingRules() overrides := &clientcmd.ConfigOverrides{} return clientcmd.NewNonInteractiveDeferredLoadingClientConfig(rules, overrides).ClientConfig() } - log.Info("Running with in cluster client configuration") + ctx.Info("Running with in cluster client configuration") return config, err } diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index a1865d1601b..563ab93f000 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -99,7 +99,7 @@ func (l *FairSchedulingAlgo) Schedule( // Exit immediately if scheduling is disabled. if l.schedulingConfig.DisableScheduling { - ctx.Log.Info("skipping scheduling - scheduling disabled") + ctx.Info("skipping scheduling - scheduling disabled") return overallSchedulerResult, nil } @@ -121,7 +121,7 @@ func (l *FairSchedulingAlgo) Schedule( select { case <-ctxWithTimeout.Done(): // We've reached the scheduling time limit; exit gracefully. - ctx.Log.Info("ending scheduling round early as we have hit the maximum scheduling duration") + ctx.Info("ending scheduling round early as we have hit the maximum scheduling duration") return overallSchedulerResult, nil default: } @@ -140,7 +140,7 @@ func (l *FairSchedulingAlgo) Schedule( // Assume pool and minimumJobSize are consistent within the group. pool := executorGroup[0].Pool minimumJobSize := executorGroup[0].MinimumJobSize - ctx.Log.Infof( + ctx.Infof( "scheduling on executor group %s with capacity %s", executorGroupLabel, fsctx.totalCapacityByPool[pool].CompactString(), ) @@ -156,14 +156,14 @@ func (l *FairSchedulingAlgo) Schedule( // add the executorGroupLabel back to l.executorGroupsToSchedule such that we try it again next time, // and exit gracefully. 
l.executorGroupsToSchedule = append(l.executorGroupsToSchedule, executorGroupLabel) - ctx.Log.Info("stopped scheduling early as we have hit the maximum scheduling duration") + ctx.Info("stopped scheduling early as we have hit the maximum scheduling duration") break } else if err != nil { return nil, err } if l.schedulingContextRepository != nil { if err := l.schedulingContextRepository.AddSchedulingContext(sctx); err != nil { - logging.WithStacktrace(ctx.Log, err).Error("failed to add scheduling context") + logging.WithStacktrace(ctx, err).Error("failed to add scheduling context") } } @@ -563,7 +563,9 @@ func (l *FairSchedulingAlgo) filterLaggingExecutors( leasedJobs := leasedJobsByExecutor[executor.Id] executorRuns, err := executor.AllRuns() if err != nil { - logging.WithStacktrace(ctx.Log, err).Errorf("failed to retrieve runs for executor %s; will not be considered for scheduling", executor.Id) + logging. + WithStacktrace(ctx, err). + Errorf("failed to retrieve runs for executor %s; will not be considered for scheduling", executor.Id) continue } executorRunIds := make(map[uuid.UUID]bool, len(executorRuns)) @@ -582,7 +584,7 @@ func (l *FairSchedulingAlgo) filterLaggingExecutors( if numUnacknowledgedJobs <= l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor { activeExecutors = append(activeExecutors, executor) } else { - ctx.Log.Warnf( + ctx.Warnf( "%d unacknowledged jobs on executor %s exceeds limit of %d; executor will not be considered for scheduling", numUnacknowledgedJobs, executor.Id, l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor, ) diff --git a/internal/scheduler/submitcheck.go b/internal/scheduler/submitcheck.go index bf79e0eb317..d390f3c5037 100644 --- a/internal/scheduler/submitcheck.go +++ b/internal/scheduler/submitcheck.go @@ -8,12 +8,12 @@ import ( lru "github.com/hashicorp/golang-lru" "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "k8s.io/apimachinery/pkg/util/clock" "github.com/armadaproject/armada/internal/armada/configuration" "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/logging" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/types" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -101,7 +101,9 @@ func (srv *SubmitChecker) Run(ctx *armadacontext.Context) error { func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) { executors, err := srv.executorRepository.GetExecutors(ctx) if err != nil { - log.WithError(err).Error("Error fetching executors") + logging. + WithStacktrace(ctx, err). + Error("Error fetching executors") return } for _, executor := range executors { @@ -114,10 +116,14 @@ func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) { } srv.mu.Unlock() if err != nil { - log.WithError(err).Errorf("Error constructing node db for executor %s", executor.Id) + logging. + WithStacktrace(ctx, err). + Errorf("Error constructing node db for executor %s", executor.Id) } } else { - log.WithError(err).Warnf("Error clearing nodedb for executor %s", executor.Id) + logging. + WithStacktrace(ctx, err). 
+ Warnf("Error clearing nodedb for executor %s", executor.Id) } } From babce23f0e2102ec7b6ef360e2ffe92a3ab2bd5e Mon Sep 17 00:00:00 2001 From: Mohamed Abdelfatah <39927413+Mo-Fatah@users.noreply.github.com> Date: Fri, 15 Sep 2023 00:28:01 +0300 Subject: [PATCH 7/9] Run `on.push` only for master (#2968) * Run On Push only for master Signed-off-by: mohamed * remove not-workflows Signed-off-by: mohamed --------- Signed-off-by: mohamed --- .github/workflows/ci.yml | 4 +- .github/workflows/not-airflow-operator.yml | 47 ---------------------- .github/workflows/not-python-client.yml | 42 ------------------- 3 files changed, 2 insertions(+), 91 deletions(-) delete mode 100644 .github/workflows/not-airflow-operator.yml delete mode 100644 .github/workflows/not-python-client.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 125b18d3096..1b688b3731b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,10 +2,10 @@ name: CI on: push: + branches: + - master tags: - v* - branches-ignore: - - gh-pages pull_request: branches-ignore: - gh-pages diff --git a/.github/workflows/not-airflow-operator.yml b/.github/workflows/not-airflow-operator.yml deleted file mode 100644 index 298cb79c0fd..00000000000 --- a/.github/workflows/not-airflow-operator.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: Python Airflow Operator - -on: - push: - branches-ignore: - - master - paths-ignore: - - 'client/python/**' - - 'build/python-client/**' - - 'pkg/api/*.proto' - - '.github/workflows/python-client.yml' - - 'docs/python_armada_client.md' - - 'scripts/build-python-client.sh' - - 'third_party/airflow/**' - - 'build/airflow-operator/**' - - 'pkg/api/jobservice/*.proto' - - '.github/workflows/airflow-operator.yml' - - 'docs/python_airflow_operator.md' - - 'scripts/build-airflow-operator.sh' - - '.github/workflows/python-tests/*' - - pull_request: - branches-ignore: - - gh-pages - paths-ignore: - - 'client/python/**' - - 'build/python-client/**' - - 'pkg/api/*.proto' - - '.github/workflows/python-client.yml' - - 'docs/python_armada_client.md' - - 'scripts/build-python-client.sh' - - 'third_party/airflow/**' - - 'build/airflow-operator/**' - - 'pkg/api/jobservice/*.proto' - - '.github/workflows/airflow-operator.yml' - - 'docs/python_airflow_operator.md' - - 'scripts/build-airflow-operator.sh' - - '.github/workflows/python-tests/*' - -jobs: - airflow-tox: - strategy: - matrix: - go: [ '1.20' ] - runs-on: ubuntu-latest - steps: - - run: 'echo "No airflow operator code modified, not running airflow operator jobs"' diff --git a/.github/workflows/not-python-client.yml b/.github/workflows/not-python-client.yml deleted file mode 100644 index a7704606427..00000000000 --- a/.github/workflows/not-python-client.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Python Client - -on: - push: - branches-ignore: - - master - paths-ignore: - - 'client/python/**' - - 'build/python-client/**' - - 'pkg/api/*.proto' - - '.github/workflows/python-client.yml' - - 'docs/python_armada_client.md' - - 'scripts/build-python-client.sh' - - '.github/workflows/python-tests/*' - - pull_request: - branches-ignore: - - gh-pages - paths-ignore: - - 'client/python/**' - - 'build/python-client/**' - - 'pkg/api/*.proto' - - '.github/workflows/python-client.yml' - - 'docs/python_armada_client.md' - - 'scripts/build-python-client.sh' - - '.github/workflows/python-tests/*' - -jobs: - python-client-tox: - strategy: - matrix: - go: [ '1.20' ] - runs-on: ubuntu-latest - steps: - - run: 'echo "No python modified, not running python jobs"' - 
python-client-integration-tests: - strategy: - matrix: - go: [ '1.20' ] - runs-on: ubuntu-latest - steps: - - run: 'echo "No python modified, not running python jobs"' From 8eb7201d20d28d74e892cd7489366e9db0c31855 Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Mon, 18 Sep 2023 10:24:31 -0600 Subject: [PATCH 8/9] WIP: Airflow: fix undefined poll_interval in Deferrable Operator (#2975) * Airflow: handle poll_interval attr in ArmadaJobCompleteTrigger Fix incomplete handling of 'poll_interval' attribute in ArmadaJobCompleteTrigger, used by the Armada Deferrable Operator for Airflow. Signed-off-by: Rich Scott * Airflow - add unit test for armada deferrable operator Run much of the same tests for the deferrable operator as for the regular operator, plus test serialization. Also, update interval signifier in examples. A full test of the deferrable operator that verifies the trigger handling is still needed. Signed-off-by: Rich Scott --------- Signed-off-by: Rich Scott --- docs/python_airflow_operator.md | 28 ++- .../armada/operators/armada_deferrable.py | 39 +++- third_party/airflow/armada/operators/utils.py | 4 +- third_party/airflow/examples/big_armada.py | 2 +- .../tests/unit/test_airflow_operator_mock.py | 4 +- .../unit/test_armada_deferrable_operator.py | 171 ++++++++++++++++++ .../test_search_for_job_complete_asyncio.py | 5 + 7 files changed, 245 insertions(+), 8 deletions(-) create mode 100644 third_party/airflow/tests/unit/test_armada_deferrable_operator.py diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index c74a464751d..048667a2562 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -239,9 +239,27 @@ Reports the result of the job and returns. +#### serialize() +Get a serialized version of this object. + + +* **Returns** + + A dict of keyword arguments used when instantiating + + + +* **Return type** + + dict + + +this object. + + #### template_fields(_: Sequence[str_ _ = ('job_request_items',_ ) -### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name) +### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30) Bases: `BaseTrigger` An airflow trigger that monitors the job state of an armada job. @@ -269,6 +287,9 @@ Triggers when the job is complete. belongs. + * **poll_interval** (*int*) – How often to poll jobservice to get status. + + * **Returns** @@ -664,7 +685,7 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED -### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, time_out_for_failure=7200) +### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200) Poll JobService cache asyncronously until you get a terminated event. A terminated event is SUCCEEDED, FAILED or CANCELLED @@ -689,6 +710,9 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED It is optional only for testing + * **poll_interval** (*int*) – How often to poll jobservice to get status. 
+ + * **time_out_for_failure** (*int*) – The amount of time a job can be in job_id_not_found before we decide it was a invalid job diff --git a/third_party/airflow/armada/operators/armada_deferrable.py b/third_party/airflow/armada/operators/armada_deferrable.py index 2f53a702228..f7aa1413637 100644 --- a/third_party/airflow/armada/operators/armada_deferrable.py +++ b/third_party/airflow/armada/operators/armada_deferrable.py @@ -103,6 +103,25 @@ def __init__( self.lookout_url_template = lookout_url_template self.poll_interval = poll_interval + def serialize(self) -> dict: + """ + Get a serialized version of this object. + + :return: A dict of keyword arguments used when instantiating + this object. + """ + + return { + "task_id": self.task_id, + "name": self.name, + "armada_channel_args": self.armada_channel_args.serialize(), + "job_service_channel_args": self.job_service_channel_args.serialize(), + "armada_queue": self.armada_queue, + "job_request_items": self.job_request_items, + "lookout_url_template": self.lookout_url_template, + "poll_interval": self.poll_interval, + } + def execute(self, context) -> None: """ Executes the Armada Operator. Only meant to be called by airflow. @@ -156,6 +175,7 @@ def execute(self, context) -> None: armada_queue=self.armada_queue, job_set_id=context["run_id"], airflow_task_name=self.name, + poll_interval=self.poll_interval, ), method_name="resume_job_complete", kwargs={"job_id": job_id}, @@ -216,6 +236,7 @@ class ArmadaJobCompleteTrigger(BaseTrigger): :param job_set_id: The ID of the job set. :param airflow_task_name: Name of the airflow task to which this trigger belongs. + :param poll_interval: How often to poll jobservice to get status. :return: An armada job complete trigger instance. """ @@ -226,6 +247,7 @@ def __init__( armada_queue: str, job_set_id: str, airflow_task_name: str, + poll_interval: int = 30, ) -> None: super().__init__() self.job_id = job_id @@ -233,6 +255,7 @@ def __init__( self.armada_queue = armada_queue self.job_set_id = job_set_id self.airflow_task_name = airflow_task_name + self.poll_interval = poll_interval def serialize(self) -> tuple: return ( @@ -243,9 +266,21 @@ def serialize(self) -> tuple: "armada_queue": self.armada_queue, "job_set_id": self.job_set_id, "airflow_task_name": self.airflow_task_name, + "poll_interval": self.poll_interval, }, ) + def __eq__(self, o): + return ( + self.task_id == o.task_id + and self.job_id == o.job_id + and self.job_service_channel_args == o.job_service_channel_args + and self.armada_queue == o.armada_queue + and self.job_set_id == o.job_set_id + and self.airflow_task_name == o.airflow_task_name + and self.poll_interval == o.poll_interval + ) + async def run(self): """ Runs the trigger. Meant to be called by an airflow triggerer process. 
@@ -255,12 +290,12 @@ async def run(self): ) job_state, job_message = await search_for_job_complete_async( - job_service_client=job_service_client, armada_queue=self.armada_queue, job_set_id=self.job_set_id, airflow_task_name=self.airflow_task_name, job_id=self.job_id, - poll_interval=self.poll_interval, + job_service_client=job_service_client, log=self.log, + poll_interval=self.poll_interval, ) yield TriggerEvent({"job_state": job_state, "job_message": job_message}) diff --git a/third_party/airflow/armada/operators/utils.py b/third_party/airflow/armada/operators/utils.py index e3c68beb321..1ab7fa35d04 100644 --- a/third_party/airflow/armada/operators/utils.py +++ b/third_party/airflow/armada/operators/utils.py @@ -217,6 +217,7 @@ async def search_for_job_complete_async( job_id: str, job_service_client: JobServiceAsyncIOClient, log, + poll_interval: int, time_out_for_failure: int = 7200, ) -> Tuple[JobState, str]: """ @@ -231,6 +232,7 @@ async def search_for_job_complete_async( :param job_id: The name of the job id that armada assigns to it :param job_service_client: A JobServiceClient that is used for polling. It is optional only for testing + :param poll_interval: How often to poll jobservice to get status. :param time_out_for_failure: The amount of time a job can be in job_id_not_found before we decide it was a invalid job @@ -251,7 +253,7 @@ async def search_for_job_complete_async( job_state = job_state_from_pb(job_status_return.state) log.debug(f"Got job state '{job_state.name}' for job {job_id}") - await asyncio.sleep(3) + await asyncio.sleep(poll_interval) if job_state == JobState.SUCCEEDED: job_message = f"Armada {airflow_task_name}:{job_id} succeeded" diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py index f1196307227..dc64cdc76b2 100644 --- a/third_party/airflow/examples/big_armada.py +++ b/third_party/airflow/examples/big_armada.py @@ -57,7 +57,7 @@ def submit_sleep_job(): with DAG( dag_id="big_armada", start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule_interval="@daily", + schedule="@daily", catchup=False, default_args={"retries": 2}, ) as dag: diff --git a/third_party/airflow/tests/unit/test_airflow_operator_mock.py b/third_party/airflow/tests/unit/test_airflow_operator_mock.py index 4634e644795..1ab2d37ced1 100644 --- a/third_party/airflow/tests/unit/test_airflow_operator_mock.py +++ b/third_party/airflow/tests/unit/test_airflow_operator_mock.py @@ -170,7 +170,7 @@ def test_annotate_job_request_items(): dag = DAG( dag_id="hello_armada", start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule_interval="@daily", + schedule="@daily", catchup=False, default_args={"retries": 2}, ) @@ -204,7 +204,7 @@ def test_parameterize_armada_operator(): dag = DAG( dag_id="hello_armada", start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule_interval="@daily", + schedule="@daily", catchup=False, default_args={"retries": 2}, ) diff --git a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py new file mode 100644 index 00000000000..0f156ed177e --- /dev/null +++ b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py @@ -0,0 +1,171 @@ +import copy + +import pytest + +from armada_client.armada import submit_pb2 +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) +from armada.operators.armada_deferrable import 
ArmadaDeferrableOperator +from armada.operators.grpc import CredentialsCallback + + +def test_serialize_armada_deferrable(): + grpc_chan_args = { + "target": "localhost:443", + "credentials_callback_args": { + "module_name": "channel_test", + "function_name": "get_credentials", + "function_kwargs": { + "example_arg": "test", + }, + }, + } + + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="sleep", + image="busybox", + args=["sleep", "10s"], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + limits={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + ), + ) + ], + ) + + job_requests = [ + submit_pb2.JobSubmitRequestItem( + priority=1, + pod_spec=pod, + namespace="personal-anonymous", + annotations={"armadaproject.io/hello": "world"}, + ) + ] + + source = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test task", + armada_channel_args=grpc_chan_args, + job_service_channel_args=grpc_chan_args, + armada_queue="test-queue", + job_request_items=job_requests, + lookout_url_template="https://lookout.test.domain/", + poll_interval=5, + ) + + serialized = source.serialize() + assert serialized["name"] == source.name + + reconstituted = ArmadaDeferrableOperator(**serialized) + assert reconstituted == source + + +get_lookout_url_test_cases = [ + ( + "http://localhost:8089/jobs?job_id=", + "test_id", + "http://localhost:8089/jobs?job_id=test_id", + ), + ( + "https://lookout.armada.domain/jobs?job_id=", + "test_id", + "https://lookout.armada.domain/jobs?job_id=test_id", + ), + ("", "test_id", ""), + (None, "test_id", ""), +] + + +@pytest.mark.parametrize( + "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases +) +def test_get_lookout_url(lookout_url_template, job_id, expected_url): + armada_channel_args = {"target": "127.0.0.1:50051"} + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template=lookout_url_template, + ) + + assert operator._get_lookout_url(job_id) == expected_url + + +def test_deepcopy_operator(): + armada_channel_args = {"target": "127.0.0.1:50051"} + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template="http://localhost:8089/jobs?job_id=", + ) + + try: + copy.deepcopy(operator) + except Exception as e: + assert False, f"{e}" + + +def test_deepcopy_operator_with_grpc_credentials_callback(): + armada_channel_args = { + "target": "127.0.0.1:50051", + "credentials_callback_args": { + "module_name": "tests.unit.test_armada_operator", + "function_name": "__example_test_callback", + "function_kwargs": { + "test_arg": "fake_arg", + }, + }, + } + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + 
job_request_items=[], + lookout_url_template="http://localhost:8089/jobs?job_id=", + ) + + try: + copy.deepcopy(operator) + except Exception as e: + assert False, f"{e}" + + +def __example_test_callback(foo=None): + return f"fake_cred {foo}" + + +def test_credentials_callback(): + callback = CredentialsCallback( + module_name="test_armada_operator", + function_name="__example_test_callback", + function_kwargs={"foo": "bar"}, + ) + + result = callback.call() + assert result == "fake_cred bar" diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py index 83cc3e220aa..a842fa994d3 100644 --- a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py +++ b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py @@ -71,6 +71,7 @@ async def test_failed_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.FAILED assert ( @@ -89,6 +90,7 @@ async def test_successful_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.SUCCEEDED assert job_complete[1] == "Armada test:test_succeeded succeeded" @@ -104,6 +106,7 @@ async def test_cancelled_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.CANCELLED assert job_complete[1] == "Armada test:test_cancelled cancelled" @@ -119,6 +122,7 @@ async def test_job_id_not_found(js_aio_client): time_out_for_failure=5, job_service_client=js_aio_client, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.JOB_ID_NOT_FOUND assert ( @@ -142,6 +146,7 @@ async def test_error_retry(js_aio_retry_client): job_service_client=js_aio_retry_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.SUCCEEDED assert job_complete[1] == "Armada test:test_succeeded succeeded" From 291ef41abf2c4425abbf46bee713b7caef5432ea Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Mon, 18 Sep 2023 11:58:26 -0600 Subject: [PATCH 9/9] Release Airflow Operator v0.5.6 (#2979) Signed-off-by: Rich Scott --- third_party/airflow/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index aa8296d46cb..d3fb7abfa6f 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "armada_airflow" -version = "0.5.5" +version = "0.5.6" description = "Armada Airflow Operator" requires-python = ">=3.7" # Note(JayF): This dependency value is not suitable for release. Whatever