diff --git a/config/scheduler/config.yaml b/config/scheduler/config.yaml index 9d5b0690d4c..75936ce42a3 100644 --- a/config/scheduler/config.yaml +++ b/config/scheduler/config.yaml @@ -9,28 +9,13 @@ queueRefreshPeriod: 10s disableSubmitCheck: false metrics: port: 9000 + jobStateMetricsResetInterval: 12h refreshInterval: 30s - metrics: - scheduleCycleTimeHistogramSettings: - start: 10.0 - factor: 1.1 - count: 110 - reconcileCycleTimeHistogramSettings: - start: 10.0 - factor: 1.1 - count: 110 -schedulerMetrics: trackedResourceNames: - "cpu" - "memory" - "ephemeral-storage" - "nvidia.com/gpu" - resourceRenaming: - nvidia.com/gpu: "gpu" - amd.com/gpu: "gpu" - ephemeral-storage: "ephemeralStorage" - matchedRegexIndexByErrorMessageCacheSize: 100 - resetInterval: "1h" pulsar: URL: "pulsar://pulsar:6650" jobsetEventsTopic: "events" diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 29709313ac9..486d47cde1c 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -12,7 +12,7 @@ This class provides integration with Airflow and Armada ## armada.operators.armada module -### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, \*\*kwargs) +### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, dry_run=False, \*\*kwargs) Bases: `BaseOperator`, `LoggingMixin` An Airflow operator that manages Job submission to Armada. @@ -33,7 +33,7 @@ and handles job cancellation if the Airflow task is killed. * **armada_queue** (*str*) – - * **job_request** (*JobSubmitRequestItem*) – + * **job_request** (*JobSubmitRequestItem** | **Callable**[**[**Context**, **jinja2.Environment**]**, **JobSubmitRequestItem**]*) – * **job_set_prefix** (*Optional**[**str**]*) – @@ -57,8 +57,9 @@ and handles job cancellation if the Airflow task is killed. * **job_acknowledgement_timeout** (*int*) – + * **dry_run** (*bool*) – + -#### _property_ client(_: ArmadaClien_ ) #### execute(context) Submits the job to Armada and polls for completion. @@ -76,6 +77,10 @@ Submits the job to Armada and polls for completion. +#### _property_ hook(_: ArmadaHoo_ ) + +#### lookout_url(job_id) + #### on_kill() Override this method to clean up subprocesses when a task instance gets killed. @@ -89,6 +94,8 @@ operator needs to be cleaned up, or it will leave ghost processes behind. +#### operator_extra_links(_: Collection[BaseOperatorLink_ _ = (LookoutLink(),_ ) + #### _property_ pod_manager(_: KubernetesPodLogManage_ ) #### render_template_fields(context, jinja_env=None) @@ -117,6 +124,8 @@ Args: #### template_fields(_: Sequence[str_ _ = ('job_request', 'job_set_prefix'_ ) + +#### template_fields_renderers(_: Dict[str, str_ _ = {'job_request': 'py'_ ) Initializes a new ArmadaOperator. @@ -132,7 +141,7 @@ Initializes a new ArmadaOperator. * **armada_queue** (*str*) – The name of the Armada queue to which the job will be submitted. - * **job_request** (*JobSubmitRequestItem*) – The job to be submitted to Armada. + * **job_request** (*JobSubmitRequestItem** | **Callable**[**[**Context**, **jinja2.Environment**]**, **JobSubmitRequestItem**]*) – The job to be submitted to Armada. 
* **job_set_prefix** (*Optional**[**str**]*) – A string to prepend to the jobSet name. @@ -156,10 +165,39 @@ for asynchronous execution. :param job_acknowledgement_timeout: The timeout in seconds to wait for a job to be acknowledged by Armada. :type job_acknowledgement_timeout: int +:param dry_run: Run Operator in dry-run mode - render Armada request and terminate. +:type dry_run: bool :param kwargs: Additional keyword arguments to pass to the BaseOperator. -### armada.operators.armada.log_exceptions(method) +### _class_ armada.operators.armada.LookoutLink() +Bases: `BaseOperatorLink` + + +#### get_link(operator, \*, ti_key) +Link to external system. + +Note: The old signature of this function was `(self, operator, dttm: datetime)`. That is still +supported at runtime but is deprecated. + + +* **Parameters** + + + * **operator** (*BaseOperator*) – The Airflow operator object this link is associated to. + + + * **ti_key** (*TaskInstanceKey*) – TaskInstance ID to return link for. + + + +* **Returns** + + link to external system + + + +#### name(_ = 'Lookout_ ) ## armada.triggers.armada module ## armada.auth module @@ -176,18 +214,10 @@ Bases: `Protocol` str - -#### serialize() - -* **Return type** - - *Tuple*[str, *Dict*[str, *Any*]] - - ## armada.model module -### _class_ armada.model.GrpcChannelArgs(target, options=None, compression=None, auth=None, auth_details=None) +### _class_ armada.model.GrpcChannelArgs(target, options=None, compression=None, auth=None) Bases: `object` @@ -197,32 +227,31 @@ Bases: `object` * **target** (*str*) – - * **options** (*Sequence**[**Tuple**[**str**, **Any**]**] **| **None*) – + * **options** (*Optional**[**Sequence**[**Tuple**[**str**, **Any**]**]**]*) – - * **compression** (*Compression** | **None*) – + * **compression** (*Optional**[**grpc.Compression**]*) – - * **auth** (*AuthMetadataPlugin** | **None*) – + * **auth** (*Optional**[**grpc.AuthMetadataPlugin**]*) – - * **auth_details** (*Dict**[**str**, **Any**] **| **None*) – +#### _static_ deserialize(data, version) +* **Parameters** -#### aio_channel() - -* **Return type** + + * **data** (*dict**[**str**, **Any**]*) – - *Channel* + * **version** (*int*) – -#### channel() * **Return type** - *Channel* + *GrpcChannelArgs* @@ -231,3 +260,50 @@ Bases: `object` * **Return type** *Dict*[str, *Any*] + + + +### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', submit_time: 'DateTime', cluster: 'Optional[str]' = None, last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN') +Bases: `object` + + +* **Parameters** + + + * **armada_queue** (*str*) – + + + * **job_id** (*str*) – + + + * **job_set_id** (*str*) – + + + * **submit_time** (*DateTime*) – + + + * **cluster** (*str** | **None*) – + + + * **last_log_time** (*DateTime** | **None*) – + + + * **job_state** (*str*) – + + + +#### armada_queue(_: st_ ) + +#### cluster(_: str | Non_ _ = Non_ ) + +#### job_id(_: st_ ) + +#### job_set_id(_: st_ ) + +#### job_state(_: st_ _ = 'UNKNOWN_ ) + +#### last_log_time(_: DateTime | Non_ _ = Non_ ) + +#### _property_ state(_: JobStat_ ) + +#### submit_time(_: DateTim_ ) diff --git a/go.mod b/go.mod index e2254960d02..d23c92d5a56 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/go-openapi/spec v0.20.14 github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.5.4 - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 @@ -35,7 +35,7 @@ require ( github.com/oklog/ulid v1.3.1 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.17.0 + github.com/prometheus/client_golang v1.19.1 github.com/renstrom/shortuuid v3.0.0+incompatible github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 @@ -78,7 +78,7 @@ require ( github.com/magefile/mage v1.14.0 github.com/minio/highwayhash v1.0.2 github.com/openconfig/goyang v1.2.0 - github.com/prometheus/common v0.45.0 + github.com/prometheus/common v0.48.0 github.com/redis/go-redis/extra/redisprometheus/v9 v9.0.5 github.com/redis/go-redis/v9 v9.5.1 github.com/segmentio/fasthash v1.0.3 @@ -161,7 +161,6 @@ require ( github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/microcosm-cc/bluemonday v1.0.25 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/go.sum b/go.sum index 2ee3f5d4e29..8aa19f8845b 100644 --- a/go.sum +++ b/go.sum @@ -239,8 +239,9 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -372,8 +373,6 @@ github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZ github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-zglob v0.0.4 h1:LQi2iOm0/fGgu80AioIJ/1j9w9Oh+9DZ39J4VAGzHQM= github.com/mattn/go-zglob v0.0.4/go.mod h1:MxxjyoXXnMxfIpxTK2GAkw1w8glPsQILx3N5wrKakiY= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= github.com/microcosm-cc/bluemonday v1.0.25 h1:4NEwSfiJ+Wva0VxN5B8OwMicaJvD8r9tlJWm9rtloEg= github.com/microcosm-cc/bluemonday v1.0.25/go.mod h1:ZIOjCQp1OrzBBPIJmfX4qDYFuhU02nx4bn030ixfHLE= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= @@ -437,13 +436,13 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pquerna/cachecontrol v0.1.0 h1:yJMy84ti9h/+OEWa752kBTKv4XC30OtVVHYv/8cTqKc= github.com/pquerna/cachecontrol v0.1.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= -github.com/prometheus/client_golang v1.17.0 
h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= -github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= -github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= -github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/redis/go-redis/extra/redisprometheus/v9 v9.0.5 h1:kvl0LOTQD23VR1R7A9vDti9msfV6mOE2+j6ngYkFsfg= diff --git a/internal/scheduler/configuration/configuration.go b/internal/scheduler/configuration/configuration.go index 98fada635aa..c6ae533b53a 100644 --- a/internal/scheduler/configuration/configuration.go +++ b/internal/scheduler/configuration/configuration.go @@ -32,10 +32,7 @@ type Configuration struct { // Configuration controlling leader election Leader LeaderConfig // Configuration controlling metrics - Metrics LegacyMetricsConfig - // Configuration for new scheduler metrics. - // Due to replace metrics configured via the above entry. - SchedulerMetrics MetricsConfig + Metrics MetricsConfig // Scheduler configuration (this is shared with the old scheduler) Scheduling SchedulingConfig Auth authconfig.AuthConfig @@ -69,28 +66,6 @@ func (c Configuration) Validate() error { return validate.Struct(c) } -type MetricsConfig struct { - // If true, disable metric collection and publishing. - Disabled bool - // Regexes used for job error categorisation. - // Specifically, the subCategory label for job failure counters is the first regex that matches the job error. - // If no regex matches, the subCategory label is the empty string. - TrackedErrorRegexes []string - // Metrics are exported for these resources. - TrackedResourceNames []v1.ResourceName - // Optionally rename resources in exported metrics. - // E.g., if ResourceRenaming["nvidia.com/gpu"] = "gpu", then metrics for resource "nvidia.com/gpu" use resource name "gpu" instead. - // This can be used to avoid illegal Prometheus metric names (e.g., for "nvidia.com/gpu" as "/" is not allowed). - // Allowed characters in resource names are [a-zA-Z_:][a-zA-Z0-9_:]* - // It can also be used to track multiple resources within the same metric, e.g., "nvidia.com/gpu" and "amd.com/gpu". - ResourceRenaming map[v1.ResourceName]string - // The first matching regex of each error message is cached in an LRU cache. - // This setting controls the cache size. - MatchedRegexIndexByErrorMessageCacheSize uint64 - // Reset metrics this often. Resetting periodically ensures inactive time series are garbage-collected. 
- ResetInterval time.Duration -} - type LeaderConfig struct { // Valid modes are "standalone" or "kubernetes" Mode string `validate:"required"` @@ -128,16 +103,16 @@ type HttpConfig struct { Port int `validate:"required"` } -// TODO: ALl this needs to be unified with MetricsConfig -type LegacyMetricsConfig struct { - Port uint16 - RefreshInterval time.Duration - Metrics SchedulerMetricsConfig -} - -type SchedulerMetricsConfig struct { - ScheduleCycleTimeHistogramSettings HistogramConfig - ReconcileCycleTimeHistogramSettings HistogramConfig +type MetricsConfig struct { + Port uint16 + RefreshInterval time.Duration + JobStateMetricsResetInterval time.Duration + // Regexes used for job error categorisation. + // Specifically, the subCategory label for job failure counters is the first regex that matches the job error. + // If no regex matches, the subCategory label is the empty string. + TrackedErrorRegexes []string + // Metrics are exported for these resources. + TrackedResourceNames []v1.ResourceName } type HistogramConfig struct { diff --git a/internal/scheduler/context/scheduling.go b/internal/scheduler/context/scheduling.go index 5917ca79d25..40be0d7aaa8 100644 --- a/internal/scheduler/context/scheduling.go +++ b/internal/scheduler/context/scheduling.go @@ -378,3 +378,17 @@ func (sctx *SchedulingContext) AllocatedByQueueAndPriority() map[string]schedule } return rv } + +// FairnessError returns the cumulative delta between adjusted fair share and actual share for all users who +// are below their fair share +func (sctx *SchedulingContext) FairnessError() float64 { + fairnessError := 0.0 + for _, qctx := range sctx.QueueSchedulingContexts { + actualShare := sctx.FairnessCostProvider.UnweightedCostFromQueue(qctx) + delta := qctx.AdjustedFairShare - actualShare + if delta > 0 { + fairnessError += delta + } + } + return fairnessError +} diff --git a/internal/scheduler/context/scheduling_test.go b/internal/scheduler/context/scheduling_test.go index 6eafded6523..8bef8d51ef6 100644 --- a/internal/scheduler/context/scheduling_test.go +++ b/internal/scheduler/context/scheduling_test.go @@ -1,6 +1,7 @@ package context import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -58,21 +59,11 @@ func TestSchedulingContextAccounting(t *testing.T) { } func TestCalculateFairShares(t *testing.T) { - zeroCpu := schedulerobjects.ResourceList{ - Resources: map[string]resource.Quantity{"cpu": resource.MustParse("0")}, - } - oneCpu := schedulerobjects.ResourceList{ - Resources: map[string]resource.Quantity{"cpu": resource.MustParse("1")}, - } - fortyCpu := schedulerobjects.ResourceList{ - Resources: map[string]resource.Quantity{"cpu": resource.MustParse("40")}, - } - oneHundredCpu := schedulerobjects.ResourceList{ - Resources: map[string]resource.Quantity{"cpu": resource.MustParse("100")}, - } - oneThousandCpu := schedulerobjects.ResourceList{ - Resources: map[string]resource.Quantity{"cpu": resource.MustParse("1000")}, - } + zeroCpu := cpu(0) + oneCpu := cpu(1) + fortyCpu := cpu(40) + oneHundredCpu := cpu(100) + oneThousandCpu := cpu(1000) tests := map[string]struct { availableResources schedulerobjects.ResourceList queueCtxs map[string]*QueueSchedulingContext @@ -208,6 +199,66 @@ func TestCalculateFairShares(t *testing.T) { } } +func TestCalculateFairnessError(t *testing.T) { + tests := map[string]struct { + availableResources schedulerobjects.ResourceList + queueCtxs map[string]*QueueSchedulingContext + expected float64 + }{ + "one queue, no error": { + availableResources: cpu(100), + queueCtxs: 
map[string]*QueueSchedulingContext{ + "queueA": {Allocated: cpu(50), AdjustedFairShare: 0.5}, + }, + expected: 0, + }, + "two queues, no error": { + availableResources: cpu(100), + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Allocated: cpu(50), AdjustedFairShare: 0.5}, + "queueB": {Allocated: cpu(50), AdjustedFairShare: 0.5}, + }, + expected: 0, + }, + "one queue with error": { + availableResources: cpu(100), + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Allocated: cpu(40), AdjustedFairShare: 0.5}, + }, + expected: 0.1, + }, + "two queues with error": { + availableResources: cpu(100), + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Allocated: cpu(40), AdjustedFairShare: 0.5}, + "queueB": {Allocated: cpu(10), AdjustedFairShare: 0.5}, + }, + expected: 0.5, + }, + "above fair share is not counted": { + availableResources: cpu(100), + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Allocated: cpu(100), AdjustedFairShare: 0.5}, + }, + expected: 0.0, + }, + "empty": { + availableResources: cpu(100), + queueCtxs: map[string]*QueueSchedulingContext{}, + expected: 0.0, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + fairnessCostProvider, err := fairness.NewDominantResourceFairness(tc.availableResources, configuration.SchedulingConfig{DominantResourceFairnessResourcesToConsider: []string{"cpu"}}) + require.NoError(t, err) + sctx := NewSchedulingContext("pool", fairnessCostProvider, nil, tc.availableResources) + sctx.QueueSchedulingContexts = tc.queueCtxs + assert.InDelta(t, tc.expected, sctx.FairnessError(), 0.00001) + }) + } +} + func testNSmallCpuJobSchedulingContext(queue, priorityClassName string, n int) []*JobSchedulingContext { rv := make([]*JobSchedulingContext, n) for i := 0; i < n; i++ { @@ -226,3 +277,9 @@ func testSmallCpuJobSchedulingContext(queue, priorityClassName string) *JobSched GangInfo: EmptyGangInfo(job), } } + +func cpu(n int) schedulerobjects.ResourceList { + return schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": resource.MustParse(fmt.Sprintf("%d", n))}, + } +} diff --git a/internal/scheduler/jobdb/job_run.go b/internal/scheduler/jobdb/job_run.go index ea00a78ef1c..c325f20d19c 100644 --- a/internal/scheduler/jobdb/job_run.go +++ b/internal/scheduler/jobdb/job_run.go @@ -280,6 +280,13 @@ func (run *JobRun) Executor() string { return run.executor } +// WithExecutor returns a copy of the job run with the executor updated. +func (run *JobRun) WithExecutor(executor string) *JobRun { + run = run.DeepCopy() + run.executor = executor + return run +} + // NodeId returns the id of the node to which the JobRun is assigned. func (run *JobRun) NodeId() string { return run.nodeId @@ -290,11 +297,25 @@ func (run *JobRun) Pool() string { return run.pool } +// WithPool returns a copy of the job run with the pool updated +func (run *JobRun) WithPool(pool string) *JobRun { + run = run.DeepCopy() + run.pool = pool + return run +} + // NodeName returns the name of the node to which the JobRun is assigned. func (run *JobRun) NodeName() string { return run.nodeName } +// WithNodeName returns a copy of the job run with the node name updated. 
+func (run *JobRun) WithNodeName(nodeName string) *JobRun { + run = run.DeepCopy() + run.nodeName = nodeName + return run +} + func (run *JobRun) ScheduledAtPriority() *int32 { return run.scheduledAtPriority } diff --git a/internal/scheduler/metrics/constants.go b/internal/scheduler/metrics/constants.go new file mode 100644 index 00000000000..dc2f070b923 --- /dev/null +++ b/internal/scheduler/metrics/constants.go @@ -0,0 +1,29 @@ +package metrics + +const ( + + // common prefix for all metric names + prefix = "armada_scheduler_" + + // Prometheus Labels + poolLabel = "pool" + queueLabel = "queue" + priorityClassLabel = "priority_class" + nodeLabel = "node" + clusterLabel = "cluster" + errorCategoryLabel = "category" + errorSubcategoryLabel = "subcategory" + stateLabel = "state" + priorStateLabel = "priorState" + resourceLabel = "resource" + + // Job state strings + queued = "queued" + running = "running" + pending = "pending" + cancelled = "cancelled" + leased = "leased" + preempted = "preempted" + failed = "failed" + succeeded = "succeeded" +) diff --git a/internal/scheduler/metrics/cycle_metrics.go b/internal/scheduler/metrics/cycle_metrics.go new file mode 100644 index 00000000000..296c491826b --- /dev/null +++ b/internal/scheduler/metrics/cycle_metrics.go @@ -0,0 +1,185 @@ +package metrics + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" +) + +var ( + poolAndQueueLabels = []string{poolLabel, queueLabel} + queueAndPriorityClassLabels = []string{queueLabel, priorityClassLabel} +) + +type cycleMetrics struct { + scheduledJobs *prometheus.CounterVec + premptedJobs *prometheus.CounterVec + consideredJobs *prometheus.GaugeVec + fairShare *prometheus.GaugeVec + adjustedFairShare *prometheus.GaugeVec + actualShare *prometheus.GaugeVec + fairnessError *prometheus.GaugeVec + demand *prometheus.GaugeVec + cappedDemand *prometheus.GaugeVec + scheduleCycleTime prometheus.Histogram + reconciliationCycleTime prometheus.Histogram +} + +func newCycleMetrics() *cycleMetrics { + return &cycleMetrics{ + scheduledJobs: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: prefix + "scheduled_jobs", + Help: "Number of events scheduled", + }, + queueAndPriorityClassLabels, + ), + + premptedJobs: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: prefix + "preempted_jobs", + Help: "Number of jobs preempted", + }, + queueAndPriorityClassLabels, + ), + + consideredJobs: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: prefix + "considered_jobs", + Help: "Number of jobs considered", + }, + poolAndQueueLabels, + ), + + fairShare: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: prefix + "fair_share", + Help: "Fair share of each queue", + }, + poolAndQueueLabels, + ), + + adjustedFairShare: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: prefix + "adjusted_fair_share", + Help: "Adjusted Fair share of each queue", + }, + poolAndQueueLabels, + ), + + actualShare: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: prefix + "actual_share", + Help: "Actual Fair share of each queue", + }, + poolAndQueueLabels, + ), + + demand: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: prefix + "demand", + Help: "Demand of each queue", + }, + poolAndQueueLabels, + ), + + cappedDemand: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: prefix + "capped_demand", + Help: "Capped Demand of each queue and pool. 
This differs from demand in that it limits demand by scheduling constraints",
+			},
+			poolAndQueueLabels,
+		),
+
+		fairnessError: prometheus.NewGaugeVec(
+			prometheus.GaugeOpts{
+				Name: prefix + "fairness_error",
+				Help: "Cumulative delta between adjusted fair share and actual share for all users who are below their fair share",
+			},
+			[]string{poolLabel},
+		),
+
+		scheduleCycleTime: prometheus.NewHistogram(
+			prometheus.HistogramOpts{
+				Name:    prefix + "schedule_cycle_times",
+				Help:    "Cycle time when in a scheduling round.",
+				Buckets: prometheus.ExponentialBuckets(10.0, 1.1, 110),
+			},
+		),
+
+		reconciliationCycleTime: prometheus.NewHistogram(
+			prometheus.HistogramOpts{
+				Name:    prefix + "reconciliation_cycle_times",
+				Help:    "Cycle time when in a reconciliation round.",
+				Buckets: prometheus.ExponentialBuckets(10.0, 1.1, 110),
+			},
+		),
+	}
+}
+
+func (m *cycleMetrics) ReportScheduleCycleTime(cycleTime time.Duration) {
+	m.scheduleCycleTime.Observe(float64(cycleTime.Milliseconds()))
+}
+
+func (m *cycleMetrics) ReportReconcileCycleTime(cycleTime time.Duration) {
+	m.reconciliationCycleTime.Observe(float64(cycleTime.Milliseconds()))
+}
+
+func (m *cycleMetrics) ReportSchedulerResult(result schedulerresult.SchedulerResult) {
+	// Metrics that depend on pool
+	for _, schedContext := range result.SchedulingContexts {
+		pool := schedContext.Pool
+		for queue, queueContext := range schedContext.QueueSchedulingContexts {
+			jobsConsidered := float64(len(queueContext.UnsuccessfulJobSchedulingContexts) + len(queueContext.SuccessfulJobSchedulingContexts))
+			actualShare := schedContext.FairnessCostProvider.UnweightedCostFromQueue(queueContext)
+			demand := schedContext.FairnessCostProvider.UnweightedCostFromAllocation(queueContext.Demand)
+			cappedDemand := schedContext.FairnessCostProvider.UnweightedCostFromAllocation(queueContext.CappedDemand)
+
+			m.consideredJobs.WithLabelValues(pool, queue).Set(jobsConsidered)
+			m.fairShare.WithLabelValues(pool, queue).Set(queueContext.FairShare)
+			m.adjustedFairShare.WithLabelValues(pool, queue).Set(queueContext.AdjustedFairShare)
+			m.actualShare.WithLabelValues(pool, queue).Set(actualShare)
+			m.demand.WithLabelValues(pool, queue).Set(demand)
+			m.cappedDemand.WithLabelValues(pool, queue).Set(cappedDemand)
+		}
+		m.fairnessError.WithLabelValues(pool).Set(schedContext.FairnessError())
+	}
+
+	for _, jobCtx := range result.ScheduledJobs {
+		m.scheduledJobs.WithLabelValues(jobCtx.Job.Queue(), jobCtx.PriorityClassName).Inc()
+	}
+
+	for _, jobCtx := range result.PreemptedJobs {
+		m.premptedJobs.WithLabelValues(jobCtx.Job.Queue(), jobCtx.PriorityClassName).Inc()
+	}
+}
+
+func (m *cycleMetrics) describe(ch chan<- *prometheus.Desc) {
+	m.scheduledJobs.Describe(ch)
+	m.premptedJobs.Describe(ch)
+	m.consideredJobs.Describe(ch)
+	m.fairShare.Describe(ch)
+	m.adjustedFairShare.Describe(ch)
+	m.actualShare.Describe(ch)
+	m.fairnessError.Describe(ch)
+	m.demand.Describe(ch)
+	m.cappedDemand.Describe(ch)
+	m.scheduleCycleTime.Describe(ch)
+	m.reconciliationCycleTime.Describe(ch)
+}
+
+func (m *cycleMetrics) collect(ch chan<- prometheus.Metric) {
+	m.scheduledJobs.Collect(ch)
+	m.premptedJobs.Collect(ch)
+	m.consideredJobs.Collect(ch)
+	m.fairShare.Collect(ch)
+	m.adjustedFairShare.Collect(ch)
+	m.actualShare.Collect(ch)
+	m.fairnessError.Collect(ch)
+	m.demand.Collect(ch)
+	m.cappedDemand.Collect(ch)
+	m.scheduleCycleTime.Collect(ch)
+	m.reconciliationCycleTime.Collect(ch)
+}
diff --git a/internal/scheduler/metrics/cycle_metrics_test.go
b/internal/scheduler/metrics/cycle_metrics_test.go new file mode 100644 index 00000000000..2f86dcb9c91 --- /dev/null +++ b/internal/scheduler/metrics/cycle_metrics_test.go @@ -0,0 +1,85 @@ +package metrics + +import ( + "fmt" + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/armadaproject/armada/internal/scheduler/configuration" + "github.com/armadaproject/armada/internal/scheduler/context" + "github.com/armadaproject/armada/internal/scheduler/fairness" + "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" + "github.com/armadaproject/armada/internal/scheduler/testfixtures" +) + +const epsilon = 1e-6 + +func TestReportStateTransitions(t *testing.T) { + fairnessCostProvider, err := fairness.NewDominantResourceFairness( + cpu(100), + configuration.SchedulingConfig{DominantResourceFairnessResourcesToConsider: []string{"cpu"}}) + require.NoError(t, err) + result := schedulerresult.SchedulerResult{ + SchedulingContexts: []*context.SchedulingContext{ + { + Pool: "pool1", + FairnessCostProvider: fairnessCostProvider, + QueueSchedulingContexts: map[string]*context.QueueSchedulingContext{ + "queue1": { + Allocated: cpu(10), + Demand: cpu(20), + CappedDemand: cpu(15), + AdjustedFairShare: 0.15, + SuccessfulJobSchedulingContexts: map[string]*context.JobSchedulingContext{ + "job1": { + Job: testfixtures.Test1Cpu4GiJob("queue1", testfixtures.PriorityClass0), + }, + "job2": { + Job: testfixtures.Test1Cpu4GiJob("queue1", testfixtures.PriorityClass0), + }, + }, + UnsuccessfulJobSchedulingContexts: map[string]*context.JobSchedulingContext{ + "job2": { + Job: testfixtures.Test1Cpu4GiJob("queue1", testfixtures.PriorityClass0), + }, + }, + }, + }, + }, + }, + } + + m := newCycleMetrics() + m.ReportSchedulerResult(result) + + poolQueue := []string{"pool1", "queue1"} + + consideredJobs := testutil.ToFloat64(m.consideredJobs.WithLabelValues(poolQueue...)) + assert.Equal(t, 3.0, consideredJobs, "consideredJobs") + + allocated := testutil.ToFloat64(m.actualShare.WithLabelValues(poolQueue...)) + assert.InDelta(t, 0.1, allocated, epsilon, "allocated") + + demand := testutil.ToFloat64(m.demand.WithLabelValues(poolQueue...)) + assert.InDelta(t, 0.2, demand, epsilon, "demand") + + cappedDemand := testutil.ToFloat64(m.cappedDemand.WithLabelValues(poolQueue...)) + assert.InDelta(t, 0.15, cappedDemand, epsilon, "cappedDemand") + + adjustedFairShare := testutil.ToFloat64(m.adjustedFairShare.WithLabelValues(poolQueue...)) + assert.InDelta(t, 0.15, adjustedFairShare, epsilon, "adjustedFairShare") + + fairnessError := testutil.ToFloat64(m.fairnessError.WithLabelValues("pool1")) + assert.InDelta(t, 0.05, fairnessError, epsilon, "fairnessError") +} + +func cpu(n int) schedulerobjects.ResourceList { + return schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": resource.MustParse(fmt.Sprintf("%d", n))}, + } +} diff --git a/internal/scheduler/metrics/metrics.go b/internal/scheduler/metrics/metrics.go index bdf4097b2b6..655c3449d9d 100644 --- a/internal/scheduler/metrics/metrics.go +++ b/internal/scheduler/metrics/metrics.go @@ -4,552 +4,56 @@ import ( "regexp" "time" - "github.com/google/uuid" - lru "github.com/hashicorp/golang-lru" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" - 
- "github.com/armadaproject/armada/internal/common/armadacontext" - "github.com/armadaproject/armada/internal/scheduler/configuration" - schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" - "github.com/armadaproject/armada/internal/scheduler/jobdb" - "github.com/armadaproject/armada/pkg/armadaevents" -) - -const ( - namespace = "armada" - subsystem = "scheduler" - - podUnschedulable = "podUnschedulable" - leaseExpired = "leaseExpired" - podError = "podError" - podLeaseReturned = "podLeaseReturned" - podTerminated = "podTerminated" - jobRunPreempted = "jobRunPreempted" - - queued = "queued" - running = "running" - pending = "pending" - cancelled = "cancelled" - leased = "leased" - preempted = "preempted" - failed = "failed" - succeeded = "succeeded" ) -// A valid metric name contains only: letters, digits(not as the first character), underscores, and colons. -// validated by the following regex -var metricNameValidationRegex = regexp.MustCompile(`^[a-zA-Z_:][a-zA-Z0-9_:]*$`) - +// Metrics is the top level scheduler metrics. type Metrics struct { - config configuration.MetricsConfig - - // For disabling metrics at runtime, e.g., if not leader. - disabled bool - - // Buffer used to avoid allocations when updating metrics. - buffer []string - - // Reset metrics periodically. - resetInterval time.Duration - timeOfMostRecentReset time.Time - - // Pre-compiled regexes for error categorisation. - errorRegexes []*regexp.Regexp - // Map from error message to the index of the first matching regex. - // Messages that match no regex map to -1. - matchedRegexIndexByErrorMessage *lru.Cache - - // Histogram of completed run durations by queue - completedRunDurations *prometheus.HistogramVec - - // Map from resource name to the counter and counterSeconds Vecs for that resource. 
- resourceCounters map[v1.ResourceName]*prometheus.CounterVec + *cycleMetrics + *jobStateMetrics } -func New(config configuration.MetricsConfig) (*Metrics, error) { - errorRegexes := make([]*regexp.Regexp, len(config.TrackedErrorRegexes)) - for i, errorRegex := range config.TrackedErrorRegexes { +func New(errorRegexes []string, trackedResourceNames []v1.ResourceName, jobStateMetricsResetInterval time.Duration) (*Metrics, error) { + compiledErrorRegexes := make([]*regexp.Regexp, len(errorRegexes)) + for i, errorRegex := range errorRegexes { if r, err := regexp.Compile(errorRegex); err != nil { return nil, errors.WithStack(err) } else { - errorRegexes[i] = r - } - } - - var matchedRegexIndexByError *lru.Cache - if config.MatchedRegexIndexByErrorMessageCacheSize > 0 { - var err error - matchedRegexIndexByError, err = lru.New(int(config.MatchedRegexIndexByErrorMessageCacheSize)) - if err != nil { - return nil, errors.WithStack(err) + compiledErrorRegexes[i] = r } } - return &Metrics{ - config: config, - - resetInterval: config.ResetInterval, - timeOfMostRecentReset: time.Now(), - - buffer: make([]string, 0, 9), - - errorRegexes: errorRegexes, - matchedRegexIndexByErrorMessage: matchedRegexIndexByError, - completedRunDurations: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: "job_run_completed_duration_seconds", - Help: "Time", - Buckets: prometheus.ExponentialBuckets( - 2, - 2, - 20), - ConstLabels: map[string]string{}, - }, - []string{"queue"}, - ), - resourceCounters: make(map[v1.ResourceName]*prometheus.CounterVec), + cycleMetrics: newCycleMetrics(), + jobStateMetrics: newJobStateMetrics(compiledErrorRegexes, trackedResourceNames, jobStateMetricsResetInterval), }, nil } -func (m *Metrics) Disable() { - if m != nil { - m.disabled = true - } +// DisableJobStateMetrics stops the jobStateMetrics from being produced. This is necessary because we only produce +// these metrics when we are leader in order to avoid double counting +func (m *Metrics) DisableJobStateMetrics() { + m.jobStateMetrics.disable() } -func (m *Metrics) Enable() { - if m != nil { - m.disabled = false - } +// EnableJobStateMetrics starts the jobStateMetrics +func (m *Metrics) EnableJobStateMetrics() { + m.jobStateMetrics.enable() } -func (m *Metrics) IsDisabled() bool { - if m == nil { - return true - } - if m.config.Disabled { - return true - } - return m.disabled +// JobStateMetricsEnabled returns true if job state metrics are enabled +func (m *Metrics) JobStateMetricsEnabled() bool { + return m.jobStateMetrics.isEnabled() } +// Describe is necessary to implement the prometheus.Collector interface func (m *Metrics) Describe(ch chan<- *prometheus.Desc) { - if m.IsDisabled() { - return - } - for _, metric := range m.resourceCounters { - metric.Describe(ch) - } - m.completedRunDurations.Describe(ch) + m.jobStateMetrics.describe(ch) + m.cycleMetrics.describe(ch) } -// Collect and then reset all metrics. -// Resetting ensures we do not build up a large number of counters over time. +// Collect is necessary to implement the prometheus.Collector interface func (m *Metrics) Collect(ch chan<- prometheus.Metric) { - if m.IsDisabled() { - return - } - for _, metric := range m.resourceCounters { - metric.Collect(ch) - } - // Reset metrics periodically. 
- t := time.Now() - if t.Sub(m.timeOfMostRecentReset) > m.resetInterval { - for _, metric := range m.resourceCounters { - metric.Reset() - } - m.timeOfMostRecentReset = t - } - m.completedRunDurations.Collect(ch) -} - -func (m *Metrics) UpdateMany( - ctx *armadacontext.Context, - jsts []jobdb.JobStateTransitions, - jobRunErrorsByRunId map[uuid.UUID]*armadaevents.Error, -) error { - for _, jst := range jsts { - if err := m.Update(ctx, jst, jobRunErrorsByRunId); err != nil { - return err - } - } - return nil -} - -func (m *Metrics) Update( - ctx *armadacontext.Context, - jst jobdb.JobStateTransitions, - jobRunErrorsByRunId map[uuid.UUID]*armadaevents.Error, -) error { - if jst.Queued { - if err := m.UpdateQueued(jst.Job); err != nil { - return err - } - } - if jst.Pending { - if err := m.UpdatePending(jst.Job); err != nil { - return err - } - } - if jst.Running { - if err := m.UpdateRunning(jst.Job); err != nil { - return err - } - } - if jst.Cancelled { - if err := m.UpdateCancelled(jst.Job); err != nil { - return err - } - } - if jst.Failed { - if err := m.UpdateFailed(ctx, jst.Job, jobRunErrorsByRunId); err != nil { - return err - } - } - if jst.Succeeded { - if err := m.UpdateSucceeded(jst.Job); err != nil { - return err - } - } - if jst.Preempted { - if err := m.UpdatePreempted(jst.Job); err != nil { - return err - } - } - // UpdateLeased is called by the scheduler directly once a job is leased. - // It is not called here to avoid double counting. - return nil -} - -func (m *Metrics) UpdateQueued(job *jobdb.Job) error { - labels := m.buffer[0:0] - labels = append(labels, "") // No priorState for queued. - labels = append(labels, queued) - labels = append(labels, "") // No category for queued. - labels = append(labels, "") // No subCategory for queued. - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, 0) -} - -func (m *Metrics) UpdatePending(job *jobdb.Job) error { - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, latestRun.PendingTime()) - labels := m.buffer[0:0] - labels = append(labels, priorState) - labels = append(labels, pending) - labels = append(labels, "") // No category for pending. - labels = append(labels, "") // No subCategory for pending. - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) UpdateCancelled(job *jobdb.Job) error { - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, latestRun.TerminatedTime()) - labels := m.buffer[0:0] - labels = append(labels, priorState) - labels = append(labels, cancelled) - labels = append(labels, "") // No category for cancelled. - labels = append(labels, "") // No subCategory for cancelled. - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) UpdateFailed(ctx *armadacontext.Context, job *jobdb.Job, jobRunErrorsByRunId map[uuid.UUID]*armadaevents.Error) error { - category, subCategory := m.failedCategoryAndSubCategoryFromJob(ctx, job, jobRunErrorsByRunId) - if category == jobRunPreempted { - // It is safer to UpdatePreempted from preemption errors and not from the scheduler cycle result. - // e.g. The scheduler might decide to preempt a job, but before the job is preempted, it happens to succeed, - // in which case it should be reported as a success, not a preemption. 
- return m.UpdatePreempted(job) - } - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, latestRun.TerminatedTime()) - labels := m.buffer[0:0] - labels = append(labels, priorState) - labels = append(labels, failed) - labels = append(labels, category) - labels = append(labels, subCategory) - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) UpdateSucceeded(job *jobdb.Job) error { - labels := m.buffer[0:0] - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, latestRun.TerminatedTime()) - labels = append(labels, priorState) - labels = append(labels, succeeded) - labels = append(labels, "") // No category for succeeded. - labels = append(labels, "") // No subCategory for succeeded. - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) UpdateLeased(jctx *schedulercontext.JobSchedulingContext) error { - job := jctx.Job - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, &jctx.Created) - labels := m.buffer[0:0] - labels = append(labels, priorState) - labels = append(labels, leased) - labels = append(labels, "") // No category for leased. - labels = append(labels, "") // No subCategory for leased. - labels = appendLabelsFromJob(labels, jctx.Job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) UpdatePreempted(job *jobdb.Job) error { - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, latestRun.PreemptedTime()) - labels := m.buffer[0:0] - labels = append(labels, priorState) - labels = append(labels, preempted) - labels = append(labels, "") // No category for preempted. - labels = append(labels, "") // No subCategory for preempted. - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) UpdateRunning(job *jobdb.Job) error { - latestRun := job.LatestRun() - duration, priorState := stateDuration(job, latestRun, latestRun.RunningTime()) - labels := m.buffer[0:0] - labels = append(labels, priorState) - labels = append(labels, running) - labels = append(labels, "") // No category for running. - labels = append(labels, "") // No subCategory for running. - labels = appendLabelsFromJob(labels, job) - - return m.updateMetrics(labels, job, duration) -} - -func (m *Metrics) failedCategoryAndSubCategoryFromJob(ctx *armadacontext.Context, job *jobdb.Job, jobRunErrorsByRunId map[uuid.UUID]*armadaevents.Error) (category, subCategory string) { - run := job.LatestRun() - if run == nil { - return - } - - category, message := errorTypeAndMessageFromError(ctx, jobRunErrorsByRunId[run.Id()]) - i, ok := m.regexIndexFromErrorMessage(message) - if ok { - subCategory = m.config.TrackedErrorRegexes[i] - } - - return -} - -func (m *Metrics) regexIndexFromErrorMessage(message string) (int, bool) { - i, ok := m.cachedRegexIndexFromErrorMessage(message) - if !ok { - i, ok = m.indexOfFirstMatchingRegexFromErrorMessage(message) - if !ok { - // Use -1 to indicate that no regex matches. 
- i = -1 - } - if m.matchedRegexIndexByErrorMessage != nil { - m.matchedRegexIndexByErrorMessage.Add(message, i) - } - } - if i == -1 { - ok = false - } - return i, ok -} - -func (m *Metrics) cachedRegexIndexFromErrorMessage(message string) (int, bool) { - if m.matchedRegexIndexByErrorMessage == nil { - return 0, false - } - i, ok := m.matchedRegexIndexByErrorMessage.Get(message) - if !ok { - return 0, false - } - return i.(int), true -} - -func (m *Metrics) indexOfFirstMatchingRegexFromErrorMessage(message string) (int, bool) { - for i, r := range m.errorRegexes { - if r.MatchString(message) { - return i, true - } - } - return 0, false -} - -func appendLabelsFromJob(labels []string, job *jobdb.Job) []string { - executor := executorNameFromRun(job.LatestRun()) - pools := job.ResolvedPools() - pool := "" - if len(pools) > 0 { - pool = pools[0] - } - labels = append(labels, job.Queue()) - labels = append(labels, executor) - labels = append(labels, pool) - return labels -} - -func executorNameFromRun(run *jobdb.JobRun) string { - if run == nil { - // This case covers, e.g., jobs failing that have never been scheduled. - return "" - } - return run.Executor() -} - -func errorTypeAndMessageFromError(ctx *armadacontext.Context, err *armadaevents.Error) (string, string) { - if err == nil { - return "", "" - } - // The following errors relate to job run failures. - // We do not process JobRunPreemptedError as there is separate metric for preemption. - switch reason := err.Reason.(type) { - case *armadaevents.Error_PodUnschedulable: - return podUnschedulable, reason.PodUnschedulable.Message - case *armadaevents.Error_LeaseExpired: - return leaseExpired, "" - case *armadaevents.Error_PodError: - return podError, reason.PodError.Message - case *armadaevents.Error_PodLeaseReturned: - return podLeaseReturned, reason.PodLeaseReturned.Message - case *armadaevents.Error_JobRunPreemptedError: - return jobRunPreempted, "" - default: - ctx.Warnf("omitting name and message for unknown error type %T", err.Reason) - return "", "" - } -} - -func (m *Metrics) updateMetrics(labels []string, job *jobdb.Job, stateDuration time.Duration) error { - // update jobs and jobs-seconds metrics - jobs, jobsSeconds := m.counterVectorsFromResource(v1.ResourceName("jobs")) - if c, err := jobs.GetMetricWithLabelValues(labels[1:]...); err != nil { // we don't need priorState label here - return err - } else { - c.Add(1) - } - if c, err := jobsSeconds.GetMetricWithLabelValues(labels...); err != nil { - return err - } else { - c.Add(stateDuration.Seconds()) - } - - if job.HasRuns() && job.LatestRun().InTerminalState() { - m.completedRunDurations.WithLabelValues(job.Queue()).Observe(stateDuration.Seconds()) - } - - requests := job.ResourceRequirements().Requests - for _, resource := range m.config.TrackedResourceNames { - if r, ok := m.config.ResourceRenaming[resource]; ok { - resource = v1.ResourceName(r) - } - if !metricNameValidationRegex.MatchString(resource.String()) { - logrus.Warnf("Resource name is not valid for a metric name: %s", resource) - continue - } - metric, metricSeconds := m.counterVectorsFromResource(resource) - if metric == nil || metricSeconds == nil { - continue - } - c, err := metric.GetMetricWithLabelValues(labels[1:]...) // we don't need priorState label here - if err != nil { - return err - } - cSeconds, err := metricSeconds.GetMetricWithLabelValues(labels...) 
- if err != nil { - return err - } - q := requests[resource] - v := float64(q.MilliValue()) / 1000 - c.Add(v) - cSeconds.Add(v * stateDuration.Seconds()) - } - return nil -} - -// counterVectorsFromResource returns the counter and counterSeconds Vectors for the given resource name. -// If the counter and counterSeconds Vecs do not exist, they are created and stored in the resourceCounters map. -func (m *Metrics) counterVectorsFromResource(resource v1.ResourceName) (*prometheus.CounterVec, *prometheus.CounterVec) { - c, ok := m.resourceCounters[resource] - if !ok { - name := resource.String() + "_total" - c = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: name, - Help: resource.String() + "resource counter.", - }, - []string{"state", "category", "subCategory", "queue", "cluster", "pool"}, - ) - m.resourceCounters[resource] = c - } - - resourceSeconds := v1.ResourceName(resource.String() + "_seconds") - cSeconds, ok := m.resourceCounters[resourceSeconds] - if !ok { - name := resourceSeconds.String() + "_total" - cSeconds = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: name, - Help: resource.String() + "-second resource counter.", - }, - []string{"priorState", "state", "category", "subCategory", "queue", "cluster", "pool"}, - ) - m.resourceCounters[resourceSeconds] = cSeconds - } - return c, cSeconds -} - -// stateDuration returns: -// - the duration of the current state (stateTime - priorTime) -// - the prior state name -func stateDuration(job *jobdb.Job, run *jobdb.JobRun, stateTime *time.Time) (time.Duration, string) { - if stateTime == nil { - return 0, "" - } - - queuedTime := time.Unix(0, job.Created()) - diff := stateTime.Sub(queuedTime).Seconds() - prior := queued - priorTime := &queuedTime - - if run.LeaseTime() != nil { - if sub := stateTime.Sub(*run.LeaseTime()).Seconds(); sub < diff && sub > 0 { - prior = leased - priorTime = run.LeaseTime() - diff = sub - } - } - if run.PendingTime() != nil { - if sub := stateTime.Sub(*run.PendingTime()).Seconds(); sub < diff && sub > 0 { - prior = pending - priorTime = run.PendingTime() - diff = sub - } - } - if run.RunningTime() != nil { - if sub := stateTime.Sub(*run.RunningTime()).Seconds(); sub < diff && sub > 0 { - prior = running - priorTime = run.RunningTime() - } - } - // succeeded, failed, cancelled, preempted are not prior states - - return stateTime.Sub(*priorTime), prior + m.jobStateMetrics.collect(ch) + m.cycleMetrics.collect(ch) } diff --git a/internal/scheduler/metrics/metrics_test.go b/internal/scheduler/metrics/metrics_test.go deleted file mode 100644 index 448bb279f26..00000000000 --- a/internal/scheduler/metrics/metrics_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package metrics - -import ( - "testing" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/require" - v1 "k8s.io/api/core/v1" - - "github.com/armadaproject/armada/internal/common/armadacontext" - "github.com/armadaproject/armada/internal/scheduler/configuration" - "github.com/armadaproject/armada/internal/scheduler/context" - "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" - "github.com/armadaproject/armada/internal/scheduler/testfixtures" - "github.com/armadaproject/armada/pkg/armadaevents" -) - -func TestUpdate(t *testing.T) { - ctx := armadacontext.Background() - - metrics, err := New(configuration.MetricsConfig{ - TrackedErrorRegexes: nil, - TrackedResourceNames: []v1.ResourceName{"cpu"}, - ResetInterval: 24 
* time.Hour, - }) - require.NoError(t, err) - - now := time.Now() - - queuedJob := testfixtures.NewJob(uuid.NewString(), - "test-jobset", - "test-queue", - 1, - &schedulerobjects.JobSchedulingInfo{}, - true, - 0, - false, - false, - false, - time.Now().UnixNano(), - true) - - jobRunErrorsByRunId := map[uuid.UUID]*armadaevents.Error{ - uuid.MustParse(queuedJob.Id()): { - Terminal: true, - Reason: &armadaevents.Error_PodError{ - PodError: &armadaevents.PodError{ - Message: "my error", - }, - }, - }, - } - - leasedJob := queuedJob.WithNewRun("test-executor", "node1", "test-node", "test-pool", 1) - pendingJob := leasedJob.WithUpdatedRun(leasedJob.LatestRun().WithPendingTime(addSeconds(now, 1))) - runningJob := pendingJob.WithUpdatedRun(pendingJob.LatestRun().WithRunningTime(addSeconds(now, 2))) - finishedJob := runningJob.WithUpdatedRun(runningJob.LatestRun().WithTerminatedTime(addSeconds(now, 3))) - preemptedJob := finishedJob.WithUpdatedRun(runningJob.LatestRun().WithPreemptedTime(addSeconds(now, 4))) - - require.NoError(t, metrics.UpdateQueued(queuedJob)) - require.NoError(t, metrics.UpdateLeased(context.JobSchedulingContextFromJob(leasedJob))) - require.NoError(t, metrics.UpdatePending(pendingJob)) - require.NoError(t, metrics.UpdateRunning(runningJob)) - require.NoError(t, metrics.UpdateSucceeded(finishedJob)) - require.NoError(t, metrics.UpdateCancelled(finishedJob)) - require.NoError(t, metrics.UpdateFailed(ctx, finishedJob, jobRunErrorsByRunId)) - require.NoError(t, metrics.UpdatePreempted(preemptedJob)) -} - -func addSeconds(t time.Time, seconds int) *time.Time { - t = t.Add(time.Duration(seconds) * time.Second) - return &t -} diff --git a/internal/scheduler/metrics/state_metrics.go b/internal/scheduler/metrics/state_metrics.go new file mode 100644 index 00000000000..ce3894436a7 --- /dev/null +++ b/internal/scheduler/metrics/state_metrics.go @@ -0,0 +1,335 @@ +package metrics + +import ( + "regexp" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + v1 "k8s.io/api/core/v1" + + "github.com/armadaproject/armada/internal/scheduler/jobdb" + "github.com/armadaproject/armada/pkg/armadaevents" +) + +type resettableMetric interface { + prometheus.Collector + Reset() +} + +type jobStateMetrics struct { + errorRegexes []*regexp.Regexp + resetInterval time.Duration + lastResetTime time.Time + enabled bool + trackedResourceNames []v1.ResourceName + + completedRunDurations *prometheus.HistogramVec + jobStateCounterByQueue *prometheus.CounterVec + jobStateCounterByNode *prometheus.CounterVec + jobStateSecondsByQueue *prometheus.CounterVec + jobStateSecondsByNode *prometheus.CounterVec + jobStateResourceSecondsByQueue *prometheus.CounterVec + jobStateResourceSecondsByNode *prometheus.CounterVec + jobErrorsByQueue *prometheus.CounterVec + jobErrorsByNode *prometheus.CounterVec + allMetrics []resettableMetric +} + +func newJobStateMetrics(errorRegexes []*regexp.Regexp, trackedResourceNames []v1.ResourceName, resetInterval time.Duration) *jobStateMetrics { + completedRunDurations := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: prefix + "job_run_completed_duration_seconds", + Help: "Time", + Buckets: prometheus.ExponentialBuckets(2, 2, 20), + }, + []string{queueLabel, poolLabel}, + ) + jobStateCounterByQueue := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: prefix + "job_state_counter_by_queue", + Help: "Job states at queue level", + }, + []string{queueLabel, poolLabel, stateLabel, priorStateLabel}, + ) + jobStateCounterByNode := 
prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "job_state_counter_by_node",
+			Help: "Job states at node level",
+		},
+		[]string{nodeLabel, poolLabel, clusterLabel, stateLabel, priorStateLabel},
+	)
+	jobStateSecondsByQueue := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "job_state_seconds_by_queue",
+			Help: "time spent in different states at the queue level",
+		},
+		[]string{queueLabel, poolLabel, stateLabel, priorStateLabel},
+	)
+	jobStateSecondsByNode := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "job_state_seconds_by_node",
+			Help: "time spent in different states at the node level",
+		},
+		[]string{nodeLabel, poolLabel, clusterLabel, stateLabel, priorStateLabel},
+	)
+	jobStateResourceSecondsByQueue := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "job_state_resource_seconds_by_queue",
+			Help: "Resource-seconds spent in different states at the queue level",
+		},
+		[]string{queueLabel, poolLabel, stateLabel, priorStateLabel, resourceLabel},
+	)
+	jobStateResourceSecondsByNode := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "job_state_resource_seconds_by_node",
+			Help: "Resource-seconds spent in different states at the node level",
+		},
+		[]string{nodeLabel, poolLabel, clusterLabel, stateLabel, priorStateLabel, resourceLabel},
+	)
+	jobErrorsByQueue := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "job_error_classification_by_queue",
+			Help: "Failed jobs by error classification at the queue level",
+		},
+		[]string{queueLabel, poolLabel, errorCategoryLabel, errorSubcategoryLabel},
+	)
+	jobErrorsByNode := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: prefix + "error_classification_by_node",
+			Help: "Failed jobs by error classification at the node level",
+		},
+		[]string{nodeLabel, poolLabel, clusterLabel, errorCategoryLabel, errorSubcategoryLabel},
+	)
+	return &jobStateMetrics{
+		errorRegexes:                   errorRegexes,
+		trackedResourceNames:           trackedResourceNames,
+		resetInterval:                  resetInterval,
+		lastResetTime:                  time.Now(),
+		enabled:                        true,
+		completedRunDurations:          completedRunDurations,
+		jobStateCounterByQueue:         jobStateCounterByQueue,
+		jobStateCounterByNode:          jobStateCounterByNode,
+		jobStateSecondsByQueue:         jobStateSecondsByQueue,
+		jobStateSecondsByNode:          jobStateSecondsByNode,
+		jobStateResourceSecondsByQueue: jobStateResourceSecondsByQueue,
+		jobStateResourceSecondsByNode:  jobStateResourceSecondsByNode,
+		jobErrorsByQueue:               jobErrorsByQueue,
+		jobErrorsByNode:                jobErrorsByNode,
+		allMetrics: []resettableMetric{
+			completedRunDurations,
+			jobStateCounterByQueue,
+			jobStateCounterByNode,
+			jobStateSecondsByQueue,
+			jobStateSecondsByNode,
+			jobStateResourceSecondsByQueue,
+			jobStateResourceSecondsByNode,
+			jobErrorsByQueue,
+			jobErrorsByNode,
+		},
+	}
+}
+
+func (m *jobStateMetrics) describe(ch chan<- *prometheus.Desc) {
+	if m.enabled {
+		for _, metric := range m.allMetrics {
+			metric.Describe(ch)
+		}
+	}
+}
+
+func (m *jobStateMetrics) collect(ch chan<- prometheus.Metric) {
+	if m.enabled {
+		// Reset metrics periodically.
+		if time.Now().Sub(m.lastResetTime) > m.resetInterval {
+			m.reset()
+		}
+		for _, metric := range m.allMetrics {
+			metric.Collect(ch)
+		}
+	}
+}
+
+// ReportJobLeased reports the job as being leased. This has to be reported separately because the state transition
+// logic does not cover the leased state.
+func (m *jobStateMetrics) ReportJobLeased(job *jobdb.Job) { + run := job.LatestRun() + duration, priorState := stateDuration(job, run, run.LeaseTime()) + m.updateStateDuration(job, leased, priorState, duration) +} + +func (m *jobStateMetrics) ReportStateTransitions( + jsts []jobdb.JobStateTransitions, + jobRunErrorsByRunId map[uuid.UUID]*armadaevents.Error, +) { + for _, jst := range jsts { + job := jst.Job + run := job.LatestRun() + if jst.Pending { + duration, priorState := stateDuration(job, run, run.PendingTime()) + m.updateStateDuration(job, pending, priorState, duration) + } + if jst.Running { + duration, priorState := stateDuration(job, run, run.RunningTime()) + m.updateStateDuration(job, running, priorState, duration) + } + if jst.Cancelled { + duration, priorState := stateDuration(job, run, run.TerminatedTime()) + m.updateStateDuration(job, cancelled, priorState, duration) + if job.LatestRun() != nil { + m.completedRunDurations.WithLabelValues(job.Queue(), run.Pool()).Observe(duration) + } + } + if jst.Failed { + duration, priorState := stateDuration(job, run, run.TerminatedTime()) + m.updateStateDuration(job, failed, priorState, duration) + m.completedRunDurations.WithLabelValues(job.Queue(), run.Pool()).Observe(duration) + jobRunError := jobRunErrorsByRunId[run.Id()] + category, subCategory := m.failedCategoryAndSubCategoryFromJob(jobRunError) + m.jobErrorsByQueue.WithLabelValues(job.Queue(), run.Executor(), category, subCategory).Inc() + } + if jst.Succeeded { + duration, priorState := stateDuration(job, run, run.TerminatedTime()) + m.updateStateDuration(job, succeeded, priorState, duration) + m.completedRunDurations.WithLabelValues(job.Queue(), run.Pool()).Observe(duration) + } + if jst.Preempted { + duration, priorState := stateDuration(job, run, run.PreemptedTime()) + m.updateStateDuration(job, preempted, priorState, duration) + m.completedRunDurations.WithLabelValues(job.Queue(), run.Pool()).Observe(duration) + } + } +} + +func (m *jobStateMetrics) updateStateDuration(job *jobdb.Job, state string, priorState string, duration float64) { + if duration <= 0 { + return + } + + queue := job.Queue() + requests := job.ResourceRequirements().Requests + latestRun := job.LatestRun() + pool := "" + node := "" + cluster := "" + if latestRun != nil { + pool = latestRun.Pool() + node = latestRun.NodeName() + cluster = latestRun.Executor() + } + + // Counters + m.jobStateCounterByQueue. + WithLabelValues(queue, pool, state, priorState).Inc() + + m.jobStateCounterByNode. + WithLabelValues(node, pool, cluster, state, priorState).Inc() + + // State seconds + m.jobStateSecondsByQueue. + WithLabelValues(queue, pool, state, priorState).Add(duration) + + m.jobStateSecondsByNode. + WithLabelValues(node, pool, cluster, state, priorState).Add(duration) + + // Resource Seconds + for _, res := range m.trackedResourceNames { + resQty := requests[res] + resSeconds := duration * float64(resQty.MilliValue()) / 1000 + m.jobStateResourceSecondsByQueue. + WithLabelValues(queue, pool, state, priorState, res.String()).Add(resSeconds) + m.jobStateResourceSecondsByNode. 
+ WithLabelValues(node, pool, cluster, state, priorState, res.String()).Add(resSeconds) + } +} + +func (m *jobStateMetrics) failedCategoryAndSubCategoryFromJob(err *armadaevents.Error) (string, string) { + category, message := errorTypeAndMessageFromError(err) + for _, r := range m.errorRegexes { + if r.MatchString(message) { + return category, r.String() + } + } + return category, "" +} + +func (m *jobStateMetrics) reset() { + m.jobStateCounterByNode.Reset() + for _, metric := range m.allMetrics { + metric.Reset() + } + m.lastResetTime = time.Now() +} + +func (m *jobStateMetrics) disable() { + m.reset() + m.enabled = false +} + +func (m *jobStateMetrics) enable() { + m.enabled = true +} + +// isEnabled returns true if job state metrics are enabled +func (m *jobStateMetrics) isEnabled() bool { + return m.enabled +} + +// stateDuration returns: +// - the duration of the current state (stateTime - priorTime) +// - the prior state name +func stateDuration(job *jobdb.Job, run *jobdb.JobRun, stateTime *time.Time) (float64, string) { + if stateTime == nil { + return 0, "" + } + + queuedTime := time.Unix(0, job.Created()) + diff := stateTime.Sub(queuedTime).Seconds() + prior := queued + priorTime := &queuedTime + + if run.LeaseTime() != nil { + if sub := stateTime.Sub(*run.LeaseTime()).Seconds(); sub < diff && sub > 0 { + prior = leased + priorTime = run.LeaseTime() + diff = sub + } + } + if run.PendingTime() != nil { + if sub := stateTime.Sub(*run.PendingTime()).Seconds(); sub < diff && sub > 0 { + prior = pending + priorTime = run.PendingTime() + diff = sub + } + } + if run.RunningTime() != nil { + if sub := stateTime.Sub(*run.RunningTime()).Seconds(); sub < diff && sub > 0 { + prior = running + priorTime = run.RunningTime() + } + } + // succeeded, failed, cancelled, preempted are not prior states + return stateTime.Sub(*priorTime).Seconds(), prior +} + +func errorTypeAndMessageFromError(err *armadaevents.Error) (string, string) { + if err == nil { + return "", "" + } + // The following errors relate to job run failures. + // We do not process JobRunPreemptedError as there is separate metric for preemption. 
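failedCategoryAndSubCategoryFromJob above reports, as the error sub-category, the first configured regex that matches the failure message (the regexes come from the scheduler's TrackedErrorRegexes metrics config, wired up in schedulerapp.go later in this diff); the error category itself comes from the switch that follows. A small sketch of just the matching step, with made-up patterns and messages:

```go
package main

import (
	"fmt"
	"regexp"
)

// subCategory mirrors the sub-category lookup above: the first configured pattern
// that matches the failure message is reported; no match means an empty sub-category.
func subCategory(message string, patterns []*regexp.Regexp) string {
	for _, r := range patterns {
		if r.MatchString(message) {
			return r.String()
		}
	}
	return ""
}

func main() {
	// Example patterns only; real values come from the scheduler's metrics configuration.
	patterns := []*regexp.Regexp{
		regexp.MustCompile("OOMKilled"),
		regexp.MustCompile("ImagePullBackOff"),
	}
	fmt.Println(subCategory("container was OOMKilled", patterns)) // "OOMKilled"
	fmt.Println(subCategory("disk pressure on node", patterns))   // "" (unclassified)
}
```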
+ switch reason := err.Reason.(type) { + case *armadaevents.Error_PodUnschedulable: + return "podUnschedulable", reason.PodUnschedulable.Message + case *armadaevents.Error_LeaseExpired: + return "leaseExpired", "" + case *armadaevents.Error_PodError: + return "podError", reason.PodError.Message + case *armadaevents.Error_PodLeaseReturned: + return "podLeaseReturned", reason.PodLeaseReturned.Message + case *armadaevents.Error_JobRunPreemptedError: + return "jobRunPreempted", "" + default: + return "", "" + } +} diff --git a/internal/scheduler/metrics/state_metrics_test.go b/internal/scheduler/metrics/state_metrics_test.go new file mode 100644 index 00000000000..0203cdcee2f --- /dev/null +++ b/internal/scheduler/metrics/state_metrics_test.go @@ -0,0 +1,409 @@ +package metrics + +import ( + "regexp" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + + "github.com/armadaproject/armada/internal/scheduler/jobdb" + "github.com/armadaproject/armada/internal/scheduler/testfixtures" + "github.com/armadaproject/armada/pkg/armadaevents" +) + +const ( + testPool = "testPool" + testNode = "testNode" + testCluster = "testCluster" + testQueue = testfixtures.TestQueue + testPriorityClass = testfixtures.PriorityClass0 +) + +var ( + baseTime = time.Now() + + baseRun = jobdb.MinimalRun(uuid.New(), baseTime.UnixNano()). + WithPool(testPool).WithNodeName(testNode). + WithExecutor(testCluster) + + baseJob = testfixtures.Test16Cpu128GiJob(testQueue, testPriorityClass). + WithSubmittedTime(baseTime.UnixNano()) +) + +func TestReportJobStateTransitions(t *testing.T) { + baseTimePlusSeconds := func(numSeconds int) *time.Time { + newTime := baseTime.Add(time.Duration(numSeconds) * time.Second) + return &newTime + } + + tests := map[string]struct { + errorRegexes []*regexp.Regexp + trackedResourceNames []v1.ResourceName + jsts []jobdb.JobStateTransitions + jobRunErrorsByRunId map[uuid.UUID]*armadaevents.Error + expectedJobStateCounterByQueue map[[4]string]float64 + expectedJobStateCounterByNode map[[5]string]float64 + expectedJobStateSecondsByQueue map[[4]string]float64 + expectedJobStateSecondsByNode map[[5]string]float64 + expectedJobStateResourceSecondsByQueue map[[5]string]float64 + expectedJobStateResourceSecondsByNode map[[6]string]float64 + }{ + "Pending": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(60)). 
+ WithPendingTime(baseTimePlusSeconds(62))), + Pending: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "pending", "leased"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "pending", "leased"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "pending", "leased"}: 2, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "pending", "leased"}: 2, + }, + expectedJobStateResourceSecondsByQueue: map[[5]string]float64{ + {testQueue, testPool, "pending", "leased", "cpu"}: 2 * 16, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "pending", "leased", "cpu"}: 2 * 16, + }, + }, + "Running": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(60)). + WithPendingTime(baseTimePlusSeconds(62)). + WithRunningTime(baseTimePlusSeconds(72))), + Running: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "running", "pending"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "running", "pending"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "running", "pending"}: 10, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "running", "pending"}: 10, + }, + expectedJobStateResourceSecondsByQueue: map[[5]string]float64{ + {testQueue, testPool, "running", "pending", "cpu"}: 10 * 16, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "running", "pending", "cpu"}: 10 * 16, + }, + }, + "Succeeded": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(60)). + WithPendingTime(baseTimePlusSeconds(62)). + WithRunningTime(baseTimePlusSeconds(72)). + WithTerminatedTime(baseTimePlusSeconds(80))), + Succeeded: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "succeeded", "running"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "succeeded", "running"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "succeeded", "running"}: 8, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "succeeded", "running"}: 8, + }, + expectedJobStateResourceSecondsByQueue: map[[5]string]float64{ + {testQueue, testPool, "succeeded", "running", "cpu"}: 8 * 16, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "succeeded", "running", "cpu"}: 8 * 16, + }, + }, + "Cancelled": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(60)). + WithPendingTime(baseTimePlusSeconds(62)). + WithRunningTime(baseTimePlusSeconds(72)). 
+ WithTerminatedTime(baseTimePlusSeconds(80))), + Cancelled: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "cancelled", "running"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "cancelled", "running"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "cancelled", "running"}: 8, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "cancelled", "running"}: 8, + }, + expectedJobStateResourceSecondsByQueue: map[[5]string]float64{ + {testQueue, testPool, "cancelled", "running", "cpu"}: 8 * 16, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "cancelled", "running", "cpu"}: 8 * 16, + }, + }, + "Failed": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(60)). + WithPendingTime(baseTimePlusSeconds(62)). + WithRunningTime(baseTimePlusSeconds(72)). + WithTerminatedTime(baseTimePlusSeconds(80))), + Failed: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "failed", "running"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "failed", "running"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "failed", "running"}: 8, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "failed", "running"}: 8, + }, + expectedJobStateResourceSecondsByQueue: map[[5]string]float64{ + {testQueue, testPool, "failed", "running", "cpu"}: 8 * 16, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "failed", "running", "cpu"}: 8 * 16, + }, + }, + "Preempted": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(60)). + WithPendingTime(baseTimePlusSeconds(62)). + WithRunningTime(baseTimePlusSeconds(72)). + WithPreemptedTime(baseTimePlusSeconds(80))), + Preempted: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "preempted", "running"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "preempted", "running"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "preempted", "running"}: 8, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "preempted", "running"}: 8, + }, + expectedJobStateResourceSecondsByQueue: map[[5]string]float64{ + {testQueue, testPool, "preempted", "running", "cpu"}: 8 * 16, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "preempted", "running", "cpu"}: 8 * 16, + }, + }, + "Multiple transitions": { + trackedResourceNames: []v1.ResourceName{"cpu"}, + jsts: []jobdb.JobStateTransitions{ + { + Job: baseJob. + WithUpdatedRun( + baseRun. + WithLeasedTime(baseTimePlusSeconds(1)). + WithPendingTime(baseTimePlusSeconds(3)). + WithRunningTime(baseTimePlusSeconds(6)). 
+ WithTerminatedTime(baseTimePlusSeconds(10))), + Leased: true, + Pending: true, + Running: true, + Succeeded: true, + }, + }, + expectedJobStateCounterByQueue: map[[4]string]float64{ + {testQueue, testPool, "pending", "leased"}: 1, + {testQueue, testPool, "running", "pending"}: 1, + {testQueue, testPool, "succeeded", "running"}: 1, + }, + expectedJobStateCounterByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "pending", "leased"}: 1, + {testNode, testPool, testCluster, "running", "pending"}: 1, + {testNode, testPool, testCluster, "succeeded", "running"}: 1, + }, + expectedJobStateSecondsByQueue: map[[4]string]float64{ + {testQueue, testPool, "pending", "leased"}: 2, + {testQueue, testPool, "running", "pending"}: 3, + {testQueue, testPool, "succeeded", "running"}: 4, + }, + expectedJobStateSecondsByNode: map[[5]string]float64{ + {testNode, testPool, testCluster, "pending", "leased"}: 2, + {testNode, testPool, testCluster, "running", "pending"}: 3, + {testNode, testPool, testCluster, "succeeded", "running"}: 4, + }, + expectedJobStateResourceSecondsByNode: map[[6]string]float64{ + {testNode, testPool, testCluster, "pending", "leased", "cpu"}: 32, + {testNode, testPool, testCluster, "running", "pending", "cpu"}: 48, + {testNode, testPool, testCluster, "succeeded", "running", "cpu"}: 64, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + metrics := newJobStateMetrics(tc.errorRegexes, tc.trackedResourceNames, 12*time.Hour) + metrics.ReportStateTransitions(tc.jsts, tc.jobRunErrorsByRunId) + + // jobStateCounterByQueue + for k, v := range tc.expectedJobStateCounterByQueue { + actualCounter := testutil.ToFloat64(metrics.jobStateCounterByQueue.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualCounter, epsilon, "jobStateCounterByQueue for %s", strings.Join(k[:], ",")) + } + + // jobStateCounterByQueue + for k, v := range tc.expectedJobStateCounterByNode { + actualCounter := testutil.ToFloat64(metrics.jobStateCounterByNode.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualCounter, epsilon, "jobStateCounterByNode for %s", strings.Join(k[:], ",")) + } + + // jobStateSecondsByNode + for k, v := range tc.expectedJobStateSecondsByNode { + actualJobStateSeconds := testutil.ToFloat64(metrics.jobStateSecondsByNode.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualJobStateSeconds, epsilon, "jobStateSecondsByNode for %s", strings.Join(k[:], ",")) + } + + // jobStateSecondsByQueue + for k, v := range tc.expectedJobStateSecondsByQueue { + actualJobStateSeconds := testutil.ToFloat64(metrics.jobStateSecondsByQueue.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualJobStateSeconds, epsilon, "jobStateSecondsByQueue for %s", strings.Join(k[:], ",")) + } + + // jobStateSecondsByNode + for k, v := range tc.expectedJobStateSecondsByNode { + actualJobStateSeconds := testutil.ToFloat64(metrics.jobStateSecondsByNode.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualJobStateSeconds, epsilon, "jobStateSecondsByNode for %s", strings.Join(k[:], ",")) + } + + // jobStateResourceSecondsByQueue + for k, v := range tc.expectedJobStateResourceSecondsByQueue { + actualJobStateSeconds := testutil.ToFloat64(metrics.jobStateResourceSecondsByQueue.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualJobStateSeconds, epsilon, "jobStateResourceSecondsByQueue for %s", strings.Join(k[:], ",")) + } + + // jobStateResourceSecondsByNode + for k, v := range tc.expectedJobStateResourceSecondsByNode { + actualJobStateSeconds := 
testutil.ToFloat64(metrics.jobStateResourceSecondsByNode.WithLabelValues(k[:]...)) + assert.InDelta(t, v, actualJobStateSeconds, epsilon, "jobStateResourceSecondsByNode for %s", strings.Join(k[:], ",")) + } + }) + } +} + +func TestReset(t *testing.T) { + byQueueLabels := []string{testQueue, testPool, "running", "pending"} + byNodeLabels := []string{testNode, testPool, testCluster, "running", "pending"} + byQueueResourceLabels := append(byQueueLabels, "cpu") + byNodeResourceLabels := append(byNodeLabels, "cpu") + m := newJobStateMetrics(nil, nil, 12*time.Hour) + + testReset := func(vec *prometheus.CounterVec, labels []string) { + vec.WithLabelValues(labels...).Inc() + counterVal := testutil.ToFloat64(vec.WithLabelValues(labels...)) + assert.Equal(t, 1.0, counterVal) + m.reset() + counterVal = testutil.ToFloat64(vec.WithLabelValues(labels...)) + assert.Equal(t, 0.0, counterVal) + } + + testReset(m.jobStateCounterByQueue, byQueueLabels) + testReset(m.jobStateSecondsByNode, byNodeLabels) + testReset(m.jobStateSecondsByQueue, byQueueLabels) + testReset(m.jobStateSecondsByNode, byNodeLabels) + testReset(m.jobStateResourceSecondsByQueue, byQueueResourceLabels) + testReset(m.jobStateResourceSecondsByNode, byNodeResourceLabels) + testReset(m.jobErrorsByQueue, byQueueLabels) + testReset(m.jobErrorsByNode, byNodeLabels) +} + +func TestDisable(t *testing.T) { + byQueueLabels := []string{testQueue, testPool, "running", "pending"} + byNodeLabels := []string{testNode, testPool, testCluster, "running", "pending"} + byQueueResourceLabels := append(byQueueLabels, "cpu") + byNodeResourceLabels := append(byNodeLabels, "cpu") + + collect := func(m *jobStateMetrics) []prometheus.Metric { + m.jobStateCounterByQueue.WithLabelValues(byQueueLabels...).Inc() + m.jobStateSecondsByNode.WithLabelValues(byNodeLabels...).Inc() + m.jobStateSecondsByQueue.WithLabelValues(byQueueLabels...).Inc() + m.jobStateSecondsByNode.WithLabelValues(byNodeLabels...).Inc() + m.jobStateResourceSecondsByQueue.WithLabelValues(byQueueResourceLabels...).Inc() + m.jobStateResourceSecondsByNode.WithLabelValues(byNodeResourceLabels...).Inc() + m.jobErrorsByQueue.WithLabelValues(byQueueLabels...).Inc() + m.jobErrorsByNode.WithLabelValues(byNodeLabels...).Inc() + + ch := make(chan prometheus.Metric, 1000) + m.collect(ch) + collected := make([]prometheus.Metric, 0, len(ch)) + for len(ch) > 0 { + collected = append(collected, <-ch) + } + return collected + } + + m := newJobStateMetrics(nil, nil, 12*time.Hour) + + // Enabled + assert.NotZero(t, len(collect(m))) + + // Disabled + m.disable() + assert.Zero(t, len(collect(m))) + + // Enabled + m.enable() + assert.NotZero(t, len(collect(m))) +} diff --git a/internal/scheduler/preempting_queue_scheduler.go b/internal/scheduler/preempting_queue_scheduler.go index 67495c5f176..105f87a746e 100644 --- a/internal/scheduler/preempting_queue_scheduler.go +++ b/internal/scheduler/preempting_queue_scheduler.go @@ -22,6 +22,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/nodedb" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" ) // PreemptingQueueScheduler is a scheduler that makes a unified decisions on which jobs to preempt and schedule. 
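The hunks that follow in preempting_queue_scheduler.go (and in queue_scheduler.go, scheduler.go and scheduling_algo.go below) are mechanical: SchedulerResult now lives in its own schedulerresult package, so producers and consumers import it explicitly instead of sharing the scheduler package. A hedged sketch of what a consumer looks like after the move, using only the helpers referenced elsewhere in this diff (the package name and function here are illustrative):

```go
package scheduling

import (
	"github.com/armadaproject/armada/internal/scheduler/schedulerresult"
)

// summariseResult is a sketch only: it counts scheduled and preempted jobs from a
// SchedulerResult using the helpers this diff moves into the schedulerresult package.
func summariseResult(result *schedulerresult.SchedulerResult) (scheduled int, preempted int) {
	scheduled = len(schedulerresult.ScheduledJobsFromSchedulerResult(result))
	preempted = len(schedulerresult.PreemptedJobsFromSchedulerResult(result))
	return scheduled, preempted
}
```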
@@ -95,7 +96,7 @@ func (sch *PreemptingQueueScheduler) SkipUnsuccessfulSchedulingKeyCheck() { // Schedule // - preempts jobs belonging to queues with total allocation above their fair share and // - schedules new jobs belonging to queues with total allocation less than their fair share. -func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResult, error) { +func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*schedulerresult.SchedulerResult, error) { defer func() { sch.schedulingContext.Finished = time.Now() }() @@ -253,7 +254,7 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*Sche } ctx.WithField("stage", "scheduling-algo").Infof("Finished running assertions after scheduling round") } - return &SchedulerResult{ + return &schedulerresult.SchedulerResult{ PreemptedJobs: preemptedJobs, ScheduledJobs: scheduledJobs, NodeIdByJobId: sch.nodeIdByJobId, @@ -524,7 +525,7 @@ func addEvictedJobsToNodeDb(_ *armadacontext.Context, sctx *schedulercontext.Sch return nil } -func (sch *PreemptingQueueScheduler) schedule(ctx *armadacontext.Context, inMemoryJobRepo *InMemoryJobRepository, jobRepo JobRepository) (*SchedulerResult, error) { +func (sch *PreemptingQueueScheduler) schedule(ctx *armadacontext.Context, inMemoryJobRepo *InMemoryJobRepository, jobRepo JobRepository) (*schedulerresult.SchedulerResult, error) { jobIteratorByQueue := make(map[string]JobIterator) for _, qctx := range sch.schedulingContext.QueueSchedulingContexts { evictedIt := inMemoryJobRepo.GetJobIterator(qctx.Queue) diff --git a/internal/scheduler/preempting_queue_scheduler_test.go b/internal/scheduler/preempting_queue_scheduler_test.go index dde7f7c5144..849b3262440 100644 --- a/internal/scheduler/preempting_queue_scheduler_test.go +++ b/internal/scheduler/preempting_queue_scheduler_test.go @@ -28,6 +28,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/nodedb" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" "github.com/armadaproject/armada/internal/scheduler/testfixtures" ) @@ -2233,7 +2234,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { require.NoError(b, err) jobsByNodeId := make(map[string][]*jobdb.Job) - for _, job := range ScheduledJobsFromSchedulerResult(result) { + for _, job := range schedulerresult.ScheduledJobsFromSchedulerResult(result) { nodeId := result.NodeIdByJobId[job.Id()] jobsByNodeId[nodeId] = append(jobsByNodeId[nodeId], job) } diff --git a/internal/scheduler/queue_scheduler.go b/internal/scheduler/queue_scheduler.go index c5bfe1c14af..b2e9d2f8916 100644 --- a/internal/scheduler/queue_scheduler.go +++ b/internal/scheduler/queue_scheduler.go @@ -16,6 +16,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/floatingresources" "github.com/armadaproject/armada/internal/scheduler/nodedb" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" ) // QueueScheduler is responsible for choosing the order in which to attempt scheduling queued gangs. 
@@ -61,7 +62,7 @@ func (sch *QueueScheduler) SkipUnsuccessfulSchedulingKeyCheck() { sch.gangScheduler.SkipUnsuccessfulSchedulingKeyCheck() } -func (sch *QueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResult, error) { +func (sch *QueueScheduler) Schedule(ctx *armadacontext.Context) (*schedulerresult.SchedulerResult, error) { var scheduledJobs []*schedulercontext.JobSchedulingContext nodeIdByJobId := make(map[string]string) @@ -205,7 +206,7 @@ func (sch *QueueScheduler) Schedule(ctx *armadacontext.Context) (*SchedulerResul if len(scheduledJobs) != len(nodeIdByJobId) { return nil, errors.Errorf("only %d out of %d jobs mapped to a node", len(nodeIdByJobId), len(scheduledJobs)) } - return &SchedulerResult{ + return &schedulerresult.SchedulerResult{ PreemptedJobs: nil, ScheduledJobs: scheduledJobs, NodeIdByJobId: nodeIdByJobId, diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index e98eb44d439..b51004bbbcf 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -4,9 +4,8 @@ import ( "fmt" "time" - "github.com/gogo/protobuf/types" - "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/types" "github.com/google/uuid" "github.com/pkg/errors" "github.com/renstrom/shortuuid" @@ -23,6 +22,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/leader" "github.com/armadaproject/armada/internal/scheduler/metrics" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" "github.com/armadaproject/armada/internal/server/configuration" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -74,10 +74,8 @@ type Scheduler struct { runsSerial int64 // Function that is called every time a cycle is completed. Useful for testing. onCycleCompleted func() - // metrics set for the scheduler. - metrics *SchedulerMetrics - // New scheduler metrics due to replace the above. - schedulerMetrics *metrics.Metrics + // Prometheus metrics which report the state of the scheduler + metrics *metrics.Metrics // If true, enable scheduler assertions. // In particular, assert that the jobDb is in a valid state at the end of each cycle. enableAssertions bool @@ -96,8 +94,7 @@ func NewScheduler( executorTimeout time.Duration, maxAttemptedRuns uint, nodeIdLabel string, - metrics *SchedulerMetrics, - schedulerMetrics *metrics.Metrics, + metrics *metrics.Metrics, ) (*Scheduler, error) { return &Scheduler{ jobRepository: jobRepository, @@ -117,7 +114,6 @@ func NewScheduler( jobsSerial: -1, runsSerial: -1, metrics: metrics, - schedulerMetrics: schedulerMetrics, }, nil } @@ -228,9 +224,9 @@ func (s *Scheduler) Run(ctx *armadacontext.Context) error { // This means we can start the next cycle immediately after one cycle finishes. // As state transitions are persisted and read back from the schedulerDb over later cycles, // there is no change to the jobDb, since the correct changes have already been made. -func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToken leader.LeaderToken, shouldSchedule bool) (SchedulerResult, error) { +func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToken leader.LeaderToken, shouldSchedule bool) (schedulerresult.SchedulerResult, error) { // TODO: Consider returning a slice of these instead. - overallSchedulerResult := SchedulerResult{} + overallSchedulerResult := schedulerresult.SchedulerResult{} // Update job state. 
ctx.Info("Syncing internal state with database") @@ -244,10 +240,10 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke // Only export metrics if leader. if !s.leaderController.ValidateToken(leaderToken) { ctx.Info("Not the leader so will not attempt to schedule") - s.schedulerMetrics.Disable() + s.metrics.DisableJobStateMetrics() return overallSchedulerResult, err } else { - s.schedulerMetrics.Enable() + s.metrics.EnableJobStateMetrics() } // If we've been asked to generate messages for all jobs, do so. @@ -279,10 +275,8 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke ctx.Infof("Fetched %d job run errors", len(jobRepoRunErrorsByRunId)) // Update metrics. - if !s.schedulerMetrics.IsDisabled() { - if err := s.schedulerMetrics.UpdateMany(ctx, jsts, jobRepoRunErrorsByRunId); err != nil { - return overallSchedulerResult, err - } + if !s.metrics.JobStateMetricsEnabled() { + s.metrics.ReportStateTransitions(jsts, jobRepoRunErrorsByRunId) } // Generate any eventSequences that came out of synchronising the db state. @@ -311,7 +305,7 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke // Schedule jobs. if shouldSchedule { - var result *SchedulerResult + var result *schedulerresult.SchedulerResult result, err = s.schedulingAlgo.Schedule(ctx, txn) if err != nil { return overallSchedulerResult, err @@ -350,28 +344,15 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke txn.Commit() ctx.Info("Completed committing cycle transaction") - // Update metrics based on overallSchedulerResult. - if err := s.updateMetricsFromSchedulerResult(ctx, overallSchedulerResult); err != nil { - return overallSchedulerResult, err + if s.metrics.JobStateMetricsEnabled() { + for _, jctx := range overallSchedulerResult.ScheduledJobs { + s.metrics.ReportJobLeased(jctx.Job) + } } return overallSchedulerResult, nil } -func (s *Scheduler) updateMetricsFromSchedulerResult(ctx *armadacontext.Context, overallSchedulerResult SchedulerResult) error { - if s.schedulerMetrics.IsDisabled() { - return nil - } - for _, jctx := range overallSchedulerResult.ScheduledJobs { - if err := s.schedulerMetrics.UpdateLeased(jctx); err != nil { - return err - } - } - // UpdatePreempted is called from within UpdateFailed if the job has a JobRunPreemptedError. - // This is to make sure that preemption is counted only when the job is actually preempted, not when the scheduler decides to preempt it. - return nil -} - // syncState updates jobs in jobDb to match state in postgres and returns all updated jobs. func (s *Scheduler) syncState(ctx *armadacontext.Context) ([]*jobdb.Job, []jobdb.JobStateTransitions, error) { txn := s.jobDb.WriteTxn() @@ -474,14 +455,14 @@ func (s *Scheduler) addNodeAntiAffinitiesForAttemptedRunsIfSchedulable(ctx *arma } // eventsFromSchedulerResult generates necessary EventSequences from the provided SchedulerResult. -func (s *Scheduler) eventsFromSchedulerResult(result *SchedulerResult) ([]*armadaevents.EventSequence, error) { +func (s *Scheduler) eventsFromSchedulerResult(result *schedulerresult.SchedulerResult) ([]*armadaevents.EventSequence, error) { return EventsFromSchedulerResult(result, s.clock.Now()) } // EventsFromSchedulerResult generates necessary EventSequences from the provided SchedulerResult. 
-func EventsFromSchedulerResult(result *SchedulerResult, time time.Time) ([]*armadaevents.EventSequence, error) { +func EventsFromSchedulerResult(result *schedulerresult.SchedulerResult, time time.Time) ([]*armadaevents.EventSequence, error) { eventSequences := make([]*armadaevents.EventSequence, 0, len(result.PreemptedJobs)+len(result.ScheduledJobs)) - eventSequences, err := AppendEventSequencesFromPreemptedJobs(eventSequences, PreemptedJobsFromSchedulerResult(result), time) + eventSequences, err := AppendEventSequencesFromPreemptedJobs(eventSequences, schedulerresult.PreemptedJobsFromSchedulerResult(result), time) if err != nil { return nil, err } diff --git a/internal/scheduler/scheduler_metrics.go b/internal/scheduler/scheduler_metrics.go deleted file mode 100644 index cfa7543c87a..00000000000 --- a/internal/scheduler/scheduler_metrics.go +++ /dev/null @@ -1,290 +0,0 @@ -package scheduler - -import ( - "fmt" - "time" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/armadaproject/armada/internal/scheduler/configuration" - schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" -) - -const ( - NAMESPACE = "armada" - SUBSYSTEM = "scheduler" -) - -type SchedulerMetrics struct { - // Cycle time when scheduling, as leader. - scheduleCycleTime prometheus.Histogram - // Cycle time when reconciling, as leader or follower. - reconcileCycleTime prometheus.Histogram - - mostRecentSchedulingRoundData schedulingRoundData -} - -var scheduledJobsDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "scheduled_jobs"), - "Number of jobs scheduled each round.", - []string{ - "queue", - "priority_class", - }, nil, -) - -var preemptedJobsDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "preempted_jobs"), - "Number of jobs preempted each round.", - []string{ - "queue", - "priority_class", - }, nil, -) - -var consideredJobsDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "considered_jobs"), - "Number of jobs considered in the most recent round per queue and pool.", - []string{ - "queue", - "pool", - }, nil, -) - -var fairSharePerQueueDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "fair_share"), - "Fair share of each queue and pool.", - []string{ - "queue", - "pool", - }, nil, -) - -var adjustedFairSharePerQueueDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "adjusted_fair_share"), - "Adjusted Fair share of each queue and pool.", - []string{ - "queue", - "pool", - }, nil, -) - -var actualSharePerQueueDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "actual_share"), - "Actual share of each queue and pool.", - []string{ - "queue", - "pool", - }, nil, -) - -var demandPerQueueDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "demand"), - "Demand of each queue and pool.", - []string{ - "queue", - "pool", - }, nil, -) - -var cappedDemandPerQueueDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "capped_demand"), - "Capped Demand of each queue and pool. 
This differs from demand in that it limits demand by scheduling constraints", - []string{ - "queue", - "pool", - }, nil, -) - -var fairnessErrorDesc = prometheus.NewDesc( - fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "fairness_error"), - "Cumulative delta between adjusted fair share and actual share for all users who are below their fair share", - []string{ - "pool", - }, nil, -) - -func NewSchedulerMetrics(config configuration.SchedulerMetricsConfig) *SchedulerMetrics { - scheduleCycleTime := prometheus.NewHistogram( - prometheus.HistogramOpts{ - Namespace: NAMESPACE, - Subsystem: SUBSYSTEM, - Name: "schedule_cycle_times", - Help: "Cycle time when in a scheduling round.", - Buckets: prometheus.ExponentialBuckets( - config.ScheduleCycleTimeHistogramSettings.Start, - config.ScheduleCycleTimeHistogramSettings.Factor, - config.ScheduleCycleTimeHistogramSettings.Count), - }, - ) - - reconcileCycleTime := prometheus.NewHistogram( - prometheus.HistogramOpts{ - Namespace: NAMESPACE, - Subsystem: SUBSYSTEM, - Name: "reconcile_cycle_times", - Help: "Cycle time when outside of a scheduling round.", - Buckets: prometheus.ExponentialBuckets( - config.ReconcileCycleTimeHistogramSettings.Start, - config.ReconcileCycleTimeHistogramSettings.Factor, - config.ReconcileCycleTimeHistogramSettings.Count), - }, - ) - - prometheus.MustRegister(scheduleCycleTime) - prometheus.MustRegister(reconcileCycleTime) - - return &SchedulerMetrics{ - scheduleCycleTime: scheduleCycleTime, - reconcileCycleTime: reconcileCycleTime, - } -} - -func (m *SchedulerMetrics) ReportScheduleCycleTime(cycleTime time.Duration) { - m.scheduleCycleTime.Observe(float64(cycleTime.Milliseconds())) -} - -func (m *SchedulerMetrics) ReportReconcileCycleTime(cycleTime time.Duration) { - m.reconcileCycleTime.Observe(float64(cycleTime.Milliseconds())) -} - -func (m *SchedulerMetrics) ReportSchedulerResult(result SchedulerResult) { - qpd := m.calculateQueuePoolMetrics(result.SchedulingContexts) - currentSchedulingMetrics := schedulingRoundData{ - queuePoolData: qpd, - scheduledJobData: aggregateJobContexts(m.mostRecentSchedulingRoundData.scheduledJobData, result.ScheduledJobs), - preemptedJobData: aggregateJobContexts(m.mostRecentSchedulingRoundData.preemptedJobData, result.PreemptedJobs), - fairnessError: calculateFairnessError(qpd), - } - - m.mostRecentSchedulingRoundData = currentSchedulingMetrics -} - -func (m *SchedulerMetrics) Describe(desc chan<- *prometheus.Desc) { - desc <- scheduledJobsDesc - desc <- preemptedJobsDesc - desc <- consideredJobsDesc - desc <- fairSharePerQueueDesc - desc <- actualSharePerQueueDesc -} - -func (m *SchedulerMetrics) Collect(metrics chan<- prometheus.Metric) { - schedulingRoundData := m.mostRecentSchedulingRoundData - - schedulingRoundMetrics := generateSchedulerMetrics(schedulingRoundData) - - for _, m := range schedulingRoundMetrics { - metrics <- m - } -} - -func generateSchedulerMetrics(schedulingRoundData schedulingRoundData) []prometheus.Metric { - result := []prometheus.Metric{} - - for key, value := range schedulingRoundData.queuePoolData { - result = append(result, prometheus.MustNewConstMetric(consideredJobsDesc, prometheus.GaugeValue, float64(value.numberOfJobsConsidered), key.queue, key.pool)) - result = append(result, prometheus.MustNewConstMetric(fairSharePerQueueDesc, prometheus.GaugeValue, value.fairShare, key.queue, key.pool)) - result = append(result, prometheus.MustNewConstMetric(adjustedFairSharePerQueueDesc, prometheus.GaugeValue, value.adjustedFairShare, key.queue, key.pool)) - result = 
append(result, prometheus.MustNewConstMetric(actualSharePerQueueDesc, prometheus.GaugeValue, value.actualShare, key.queue, key.pool)) - result = append(result, prometheus.MustNewConstMetric(demandPerQueueDesc, prometheus.GaugeValue, value.demand, key.queue, key.pool)) - result = append(result, prometheus.MustNewConstMetric(cappedDemandPerQueueDesc, prometheus.GaugeValue, value.cappedDemand, key.queue, key.pool)) - } - for key, value := range schedulingRoundData.scheduledJobData { - result = append(result, prometheus.MustNewConstMetric(scheduledJobsDesc, prometheus.CounterValue, float64(value), key.queue, key.priorityClass)) - } - for key, value := range schedulingRoundData.preemptedJobData { - result = append(result, prometheus.MustNewConstMetric(preemptedJobsDesc, prometheus.CounterValue, float64(value), key.queue, key.priorityClass)) - } - - for pool, fairnessError := range schedulingRoundData.fairnessError { - result = append(result, prometheus.MustNewConstMetric(fairnessErrorDesc, prometheus.GaugeValue, fairnessError, pool)) - } - - return result -} - -// aggregateJobContexts takes a list of jobs and counts how many there are of each queue, priorityClass pair. -func aggregateJobContexts(previousSchedulingRoundData map[queuePriorityClassKey]int, jctxs []*schedulercontext.JobSchedulingContext) map[queuePriorityClassKey]int { - result := make(map[queuePriorityClassKey]int) - - for _, jctx := range jctxs { - job := jctx.Job - key := queuePriorityClassKey{queue: job.Queue(), priorityClass: job.PriorityClassName()} - result[key] += 1 - } - - for key, value := range previousSchedulingRoundData { - _, present := result[key] - if present { - result[key] += value - } else { - result[key] = value - } - } - - return result -} - -func (metrics *SchedulerMetrics) calculateQueuePoolMetrics(schedulingContexts []*schedulercontext.SchedulingContext) map[queuePoolKey]queuePoolData { - result := make(map[queuePoolKey]queuePoolData) - for _, schedContext := range schedulingContexts { - pool := schedContext.Pool - - for queue, queueContext := range schedContext.QueueSchedulingContexts { - key := queuePoolKey{queue: queue, pool: pool} - actualShare := schedContext.FairnessCostProvider.UnweightedCostFromQueue(queueContext) - demand := schedContext.FairnessCostProvider.UnweightedCostFromAllocation(queueContext.Demand) - cappedDemand := schedContext.FairnessCostProvider.UnweightedCostFromAllocation(queueContext.CappedDemand) - result[key] = queuePoolData{ - numberOfJobsConsidered: len(queueContext.UnsuccessfulJobSchedulingContexts) + len(queueContext.SuccessfulJobSchedulingContexts), - fairShare: queueContext.FairShare, - adjustedFairShare: queueContext.AdjustedFairShare, - actualShare: actualShare, - demand: demand, - cappedDemand: cappedDemand, - } - } - } - - return result -} - -// calculateFairnessError returns the cumulative delta between adjusted fair share and actual share for all users who -// are below their fair share -func calculateFairnessError(data map[queuePoolKey]queuePoolData) map[string]float64 { - errors := map[string]float64{} - for k, v := range data { - pool := k.pool - delta := v.adjustedFairShare - v.actualShare - if delta > 0 { - errors[pool] += delta - } - } - return errors -} - -type schedulingRoundData struct { - fairnessError map[string]float64 - queuePoolData map[queuePoolKey]queuePoolData - scheduledJobData map[queuePriorityClassKey]int - preemptedJobData map[queuePriorityClassKey]int -} - -type queuePriorityClassKey struct { - queue string - priorityClass string -} - -type 
queuePoolKey struct { - queue string - pool string -} - -type queuePoolData struct { - numberOfJobsConsidered int - actualShare float64 - fairShare float64 - adjustedFairShare float64 - demand float64 - cappedDemand float64 -} diff --git a/internal/scheduler/scheduler_metrics_test.go b/internal/scheduler/scheduler_metrics_test.go deleted file mode 100644 index f1ae54306e4..00000000000 --- a/internal/scheduler/scheduler_metrics_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package scheduler - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" - "github.com/armadaproject/armada/internal/scheduler/jobdb" - "github.com/armadaproject/armada/internal/scheduler/testfixtures" -) - -func TestAggregateJobs(t *testing.T) { - testJobs := []*jobdb.Job{ - testfixtures.Test1Cpu4GiJob("queue_a", testfixtures.PriorityClass0), - testfixtures.Test1Cpu4GiJob("queue_b", testfixtures.PriorityClass0), - testfixtures.Test1Cpu4GiJob("queue_a", testfixtures.PriorityClass0), - testfixtures.Test1Cpu4GiJob("queue_a", testfixtures.PriorityClass1), - testfixtures.Test1Cpu4GiJob("queue_a", testfixtures.PriorityClass0), - testfixtures.Test1Cpu4GiJob("queue_b", testfixtures.PriorityClass1), - testfixtures.Test1Cpu4GiJob("queue_a", testfixtures.PriorityClass0), - } - - actual := aggregateJobContexts(map[queuePriorityClassKey]int{}, schedulercontext.JobSchedulingContextsFromJobs(testJobs)) - - expected := map[queuePriorityClassKey]int{ - {queue: "queue_a", priorityClass: testfixtures.PriorityClass0}: 4, - {queue: "queue_a", priorityClass: testfixtures.PriorityClass1}: 1, - {queue: "queue_b", priorityClass: testfixtures.PriorityClass0}: 1, - {queue: "queue_b", priorityClass: testfixtures.PriorityClass1}: 1, - } - - assert.Equal(t, expected, actual) -} - -func TestCalculateFairnessError(t *testing.T) { - tests := map[string]struct { - input map[queuePoolKey]queuePoolData - expected map[string]float64 - }{ - "empty": { - input: map[queuePoolKey]queuePoolData{}, - expected: map[string]float64{}, - }, - "one pool": { - input: map[queuePoolKey]queuePoolData{ - {pool: "poolA", queue: "queueA"}: {actualShare: 0.5, adjustedFairShare: 0.6}, - }, - expected: map[string]float64{ - "poolA": 0.1, - }, - }, - "one pool multiple values": { - input: map[queuePoolKey]queuePoolData{ - {pool: "poolA", queue: "queueA"}: {actualShare: 0.5, adjustedFairShare: 0.6}, - {pool: "poolA", queue: "queueB"}: {actualShare: 0.1, adjustedFairShare: 0.3}, - }, - expected: map[string]float64{ - "poolA": 0.3, - }, - }, - "one pool one value above fair sahre": { - input: map[queuePoolKey]queuePoolData{ - {pool: "poolA", queue: "queueA"}: {actualShare: 0.5, adjustedFairShare: 0.6}, - {pool: "poolA", queue: "queueB"}: {actualShare: 0.3, adjustedFairShare: 0.1}, - }, - expected: map[string]float64{ - "poolA": 0.1, - }, - }, - "two pools": { - input: map[queuePoolKey]queuePoolData{ - {pool: "poolA", queue: "queueA"}: {actualShare: 0.5, adjustedFairShare: 0.6}, - {pool: "poolB", queue: "queueB"}: {actualShare: 0.1, adjustedFairShare: 0.6}, - }, - expected: map[string]float64{ - "poolA": 0.1, - "poolB": 0.5, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - fairnessErrors := calculateFairnessError(tc.input) - require.Equal(t, len(tc.expected), len(fairnessErrors)) - for pool, err := range tc.expected { - assert.InDelta(t, err, fairnessErrors[pool], 0.0001, "error for pool %s", pool) - } - }) - } -} diff --git 
a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 2322e7981a8..3b97262d20b 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -21,7 +21,6 @@ import ( "github.com/armadaproject/armada/internal/common/ingest" protoutil "github.com/armadaproject/armada/internal/common/proto" "github.com/armadaproject/armada/internal/common/util" - "github.com/armadaproject/armada/internal/scheduler/configuration" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" "github.com/armadaproject/armada/internal/scheduler/database" schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" @@ -29,7 +28,9 @@ import ( "github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/kubernetesobjects/affinity" "github.com/armadaproject/armada/internal/scheduler/leader" + "github.com/armadaproject/armada/internal/scheduler/metrics" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" "github.com/armadaproject/armada/internal/scheduler/testfixtures" "github.com/armadaproject/armada/internal/scheduleringester" apiconfig "github.com/armadaproject/armada/internal/server/configuration" @@ -113,18 +114,7 @@ var ( } schedulingInfoWithUpdatedPriorityBytes = protoutil.MustMarshall(schedulingInfoWithUpdatedPriority) - schedulerMetrics = NewSchedulerMetrics(configuration.SchedulerMetricsConfig{ - ScheduleCycleTimeHistogramSettings: configuration.HistogramConfig{ - Start: 1, - Factor: 1.1, - Count: 100, - }, - ReconcileCycleTimeHistogramSettings: configuration.HistogramConfig{ - Start: 1, - Factor: 1.1, - Count: 100, - }, - }) + schedulerMetrics, _ = metrics.New(nil, nil, 12*time.Hour) ) var queuedJob = testfixtures.NewJob( @@ -850,7 +840,6 @@ func TestScheduler_TestCycle(t *testing.T) { maxNumberOfAttempts, nodeIdLabel, schedulerMetrics, - nil, ) require.NoError(t, err) sched.EnableAssertions() @@ -1013,7 +1002,6 @@ func TestRun(t *testing.T) { maxNumberOfAttempts, nodeIdLabel, schedulerMetrics, - nil, ) require.NoError(t, err) sched.EnableAssertions() @@ -1239,7 +1227,6 @@ func TestScheduler_TestSyncState(t *testing.T) { maxNumberOfAttempts, nodeIdLabel, schedulerMetrics, - nil, ) require.NoError(t, err) sched.EnableAssertions() @@ -1358,14 +1345,14 @@ type testSchedulingAlgo struct { persisted bool } -func (t *testSchedulingAlgo) Schedule(_ *armadacontext.Context, txn *jobdb.Txn) (*SchedulerResult, error) { +func (t *testSchedulingAlgo) Schedule(_ *armadacontext.Context, txn *jobdb.Txn) (*schedulerresult.SchedulerResult, error) { t.numberOfScheduleCalls++ if t.shouldError { return nil, errors.New("error scheduling jobs") } if t.persisted { // Exit right away if decisions have already been persisted. 
- return &SchedulerResult{}, nil + return &schedulerresult.SchedulerResult{}, nil } preemptedJobs := make([]*jobdb.Job, 0, len(t.jobsToPreempt)) scheduledJobs := make([]*jobdb.Job, 0, len(t.jobsToSchedule)) @@ -1424,8 +1411,8 @@ func NewSchedulerResultForTest[S ~[]T, T *jobdb.Job]( preemptedJobs S, scheduledJobs S, nodeIdByJobId map[string]string, -) *SchedulerResult { - return &SchedulerResult{ +) *schedulerresult.SchedulerResult { + return &schedulerresult.SchedulerResult{ PreemptedJobs: schedulercontext.JobSchedulingContextsFromJobs(preemptedJobs), ScheduledJobs: schedulercontext.JobSchedulingContextsFromJobs(scheduledJobs), NodeIdByJobId: nodeIdByJobId, @@ -2346,7 +2333,6 @@ func TestCycleConsistency(t *testing.T) { maxNumberOfAttempts, nodeIdLabel, schedulerMetrics, - nil, ) require.NoError(t, err) scheduler.clock = testClock diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 746fb5e8caf..4afdd5d9333 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -247,11 +247,9 @@ func Run(config schedulerconfig.Configuration) error { resourceListFactory, floatingResourceTypes, ) - schedulingRoundMetrics := NewSchedulerMetrics(config.Metrics.Metrics) - if err := prometheus.Register(schedulingRoundMetrics); err != nil { - return errors.WithStack(err) - } - schedulerMetrics, err := metrics.New(config.SchedulerMetrics) + + schedulerMetrics, err := metrics.New( + config.Metrics.TrackedErrorRegexes, config.Metrics.TrackedResourceNames, config.Metrics.JobStateMetricsResetInterval) if err != nil { return err } @@ -272,7 +270,6 @@ func Run(config schedulerconfig.Configuration) error { config.ExecutorTimeout, config.Scheduling.MaxRetries+1, config.Scheduling.NodeIdLabel, - schedulingRoundMetrics, schedulerMetrics, ) if err != nil { @@ -340,7 +337,7 @@ func createLeaderController(ctx *armadacontext.Context, config schedulerconfig.L func loadClusterConfig(ctx *armadacontext.Context) (*rest.Config, error) { config, err := rest.InClusterConfig() - if err == rest.ErrNotInCluster { + if errors.Is(err, rest.ErrNotInCluster) { ctx.Info("Running with default client configuration") rules := clientcmd.NewDefaultClientConfigLoadingRules() overrides := &clientcmd.ConfigOverrides{} diff --git a/internal/scheduler/result.go b/internal/scheduler/schedulerresult/result.go similarity index 98% rename from internal/scheduler/result.go rename to internal/scheduler/schedulerresult/result.go index 224439855ff..7743925fd1d 100644 --- a/internal/scheduler/result.go +++ b/internal/scheduler/schedulerresult/result.go @@ -1,4 +1,4 @@ -package scheduler +package schedulerresult import ( schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index 263e765cab2..420c5fe2843 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -30,6 +30,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/queue" "github.com/armadaproject/armada/internal/scheduler/reports" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" "github.com/armadaproject/armada/pkg/api" ) @@ -38,7 +39,7 @@ import ( type SchedulingAlgo interface { // Schedule should assign jobs to nodes. // Any jobs that are scheduled should be marked as such in the JobDb using the transaction provided. 
- Schedule(*armadacontext.Context, *jobdb.Txn) (*SchedulerResult, error) + Schedule(*armadacontext.Context, *jobdb.Txn) (*schedulerresult.SchedulerResult, error) } // FairSchedulingAlgo is a SchedulingAlgo based on PreemptingQueueScheduler. @@ -96,13 +97,13 @@ func NewFairSchedulingAlgo( func (l *FairSchedulingAlgo) Schedule( ctx *armadacontext.Context, txn *jobdb.Txn, -) (*SchedulerResult, error) { +) (*schedulerresult.SchedulerResult, error) { var cancel context.CancelFunc if l.maxSchedulingDuration != 0 { ctx, cancel = armadacontext.WithTimeout(ctx, l.maxSchedulingDuration) defer cancel() } - overallSchedulerResult := &SchedulerResult{ + overallSchedulerResult := &schedulerresult.SchedulerResult{ NodeIdByJobId: make(map[string]string), } @@ -176,8 +177,8 @@ func (l *FairSchedulingAlgo) Schedule( l.schedulingContextRepository.StoreSchedulingContext(sctx) } - preemptedJobs := PreemptedJobsFromSchedulerResult(schedulerResult) - scheduledJobs := ScheduledJobsFromSchedulerResult(schedulerResult) + preemptedJobs := schedulerresult.PreemptedJobsFromSchedulerResult(schedulerResult) + scheduledJobs := schedulerresult.ScheduledJobsFromSchedulerResult(schedulerResult) if err := txn.Upsert(preemptedJobs); err != nil { return nil, err @@ -404,7 +405,7 @@ func (l *FairSchedulingAlgo) schedulePool( fsctx *fairSchedulingAlgoContext, pool string, executors []*schedulerobjects.Executor, -) (*SchedulerResult, *schedulercontext.SchedulingContext, error) { +) (*schedulerresult.SchedulerResult, *schedulercontext.SchedulingContext, error) { nodeDb, err := nodedb.NewNodeDb( l.schedulingConfig.PriorityClasses, l.schedulingConfig.IndexedResources, diff --git a/internal/scheduler/scheduling_algo_test.go b/internal/scheduler/scheduling_algo_test.go index 2b0625b8707..a6424ee3a24 100644 --- a/internal/scheduler/scheduling_algo_test.go +++ b/internal/scheduler/scheduling_algo_test.go @@ -20,6 +20,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/nodedb" "github.com/armadaproject/armada/internal/scheduler/reports" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/schedulerresult" "github.com/armadaproject/armada/internal/scheduler/testfixtures" "github.com/armadaproject/armada/pkg/api" ) @@ -507,7 +508,7 @@ func TestSchedule(t *testing.T) { require.NoError(t, err) // Check that the expected preemptions took place. - preemptedJobs := PreemptedJobsFromSchedulerResult(schedulerResult) + preemptedJobs := schedulerresult.PreemptedJobsFromSchedulerResult(schedulerResult) actualPreemptedJobsByExecutorIndexAndNodeIndex := make(map[int]map[int][]int) for _, job := range preemptedJobs { executorIndex := executorIndexByJobId[job.Id()] @@ -532,7 +533,7 @@ func TestSchedule(t *testing.T) { } // Check that jobs were scheduled as expected. 
- scheduledJobs := ScheduledJobsFromSchedulerResult(schedulerResult) + scheduledJobs := schedulerresult.ScheduledJobsFromSchedulerResult(schedulerResult) actualScheduledIndices := make([]int, 0) for _, job := range scheduledJobs { actualScheduledIndices = append(actualScheduledIndices, queueIndexByJobId[job.Id()]) diff --git a/third_party/airflow/armada/__init__.py b/third_party/airflow/armada/__init__.py new file mode 100644 index 00000000000..a0f32fe1618 --- /dev/null +++ b/third_party/airflow/armada/__init__.py @@ -0,0 +1,14 @@ +from airflow.serialization.serde import _extra_allowed + +_extra_allowed.add("armada.model.RunningJobContext") +_extra_allowed.add("armada.model.GrpcChannelArgs") + + +def get_provider_info(): + return { + "package-name": "armada-airflow", + "name": "Armada Airflow Operator", + "description": "Armada Airflow Operator.", + "extra-links": ["armada.operators.armada.LookoutLink"], + "versions": ["1.0.0"], + } diff --git a/third_party/airflow/armada/auth.py b/third_party/airflow/armada/auth.py index 16275dbc343..6bf45df780f 100644 --- a/third_party/airflow/armada/auth.py +++ b/third_party/airflow/armada/auth.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Protocol, Tuple +from typing import Protocol """ We use this interface for objects fetching Kubernetes auth tokens. Since it's used within the Trigger, it must be serialisable.""" @@ -6,5 +6,3 @@ class TokenRetriever(Protocol): def get_token(self) -> str: ... - - def serialize(self) -> Tuple[str, Dict[str, Any]]: ... diff --git a/third_party/airflow/armada/hooks.py b/third_party/airflow/armada/hooks.py new file mode 100644 index 00000000000..a894d09249e --- /dev/null +++ b/third_party/airflow/armada/hooks.py @@ -0,0 +1,129 @@ +import dataclasses +import json +import threading +from functools import cached_property +from typing import Dict, Optional + +import grpc +from airflow.exceptions import AirflowException +from airflow.serialization.serde import serialize +from airflow.utils.log.logging_mixin import LoggingMixin +from armada.model import GrpcChannelArgs +from armada_client.armada.job_pb2 import JobRunDetails +from armada_client.armada.submit_pb2 import JobSubmitRequestItem +from armada_client.client import ArmadaClient +from armada_client.typings import JobState +from pendulum import DateTime + +from .model import RunningJobContext + + +class ArmadaClientFactory: + CLIENTS_LOCK = threading.Lock() + CLIENTS: Dict[str, ArmadaClient] = {} + + @staticmethod + def client_for(args: GrpcChannelArgs) -> ArmadaClient: + """ + Armada clients, maintain GRPC connection to Armada API. + We cache them per channel args config in class level cache. + + Access to this method can be from multiple-threads. 
+        """
+        channel_args_key = json.dumps(serialize(args))
+        with ArmadaClientFactory.CLIENTS_LOCK:
+            if channel_args_key not in ArmadaClientFactory.CLIENTS:
+                ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient(
+                    channel=ArmadaClientFactory._create_channel(args)
+                )
+            return ArmadaClientFactory.CLIENTS[channel_args_key]
+
+    @staticmethod
+    def _create_channel(args: GrpcChannelArgs) -> grpc.Channel:
+        if args.auth is None:
+            return grpc.insecure_channel(
+                target=args.target, options=args.options, compression=args.compression
+            )
+
+        return grpc.secure_channel(
+            target=args.target,
+            options=args.options,
+            compression=args.compression,
+            credentials=grpc.composite_channel_credentials(
+                grpc.ssl_channel_credentials(),
+                grpc.metadata_call_credentials(args.auth),
+            ),
+        )
+
+
+class ArmadaHook(LoggingMixin):
+    def __init__(self, args: GrpcChannelArgs):
+        self.args = args
+
+    @cached_property
+    def client(self):
+        return ArmadaClientFactory.client_for(self.args)
+
+    def cancel_job(self, job_context: RunningJobContext) -> RunningJobContext:
+        try:
+            result = self.client.cancel_jobs(
+                queue=job_context.armada_queue,
+                job_set_id=job_context.job_set_id,
+                job_id=job_context.job_id,
+            )
+            if len(list(result.cancelled_ids)) > 0:
+                self.log.info(f"Cancelled job with id {result.cancelled_ids}")
+            else:
+                self.log.warning(f"Failed to cancel job with id {job_context.job_id}")
+        except Exception as e:
+            self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}")
+        finally:
+            return dataclasses.replace(job_context, job_state=JobState.CANCELLED.name)
+
+    def submit_job(
+        self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem
+    ) -> RunningJobContext:
+        resp = self.client.submit_jobs(queue, job_set_id, [job_request])
+        num_responses = len(resp.job_response_items)
+
+        # We submitted exactly one job to Armada, so we expect a single response
+        if num_responses != 1:
+            raise AirflowException(
+                f"No valid response received from Armada (expected 1 job to be "
+                f"created but got {num_responses})"
+            )
+        job = resp.job_response_items[0]
+
+        # Throw if Armada told us we had submitted something bad
+        if job.error:
+            raise AirflowException(f"Error submitting job to Armada: {job.error}")
+
+        return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow())
+
+    def refresh_context(
+        self, job_context: RunningJobContext, tracking_url: str
+    ) -> RunningJobContext:
+        response = self.client.get_job_status([job_context.job_id])
+        state = JobState(response.job_states[job_context.job_id])
+        if state != job_context.state:
+            self.log.info(
+                f"job {job_context.job_id} is in state: {state.name}. 
" + f"{tracking_url}" + ) + + cluster = job_context.cluster + if not cluster: + # Job is running / or completed already + if state == JobState.RUNNING or state.is_terminal(): + run_details = self._get_latest_job_run_details(job_context.job_id) + if run_details: + cluster = run_details.cluster + return dataclasses.replace(job_context, job_state=state.name, cluster=cluster) + + def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: + job_details = self.client.get_job_details([job_id]).job_details[job_id] + if job_details and job_details.latest_run_id: + for run in job_details.job_runs: + if run.run_id == job_details.latest_run_id: + return run + return None diff --git a/third_party/airflow/armada/model.py b/third_party/airflow/armada/model.py index 00b9ab59800..91db62420e0 100644 --- a/third_party/airflow/armada/model.py +++ b/third_party/airflow/armada/model.py @@ -1,7 +1,11 @@ -import importlib -from typing import Any, Dict, Optional, Sequence, Tuple +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, ClassVar, Dict, Optional, Sequence, Tuple import grpc +from armada_client.typings import JobState +from pendulum import DateTime """ This class exists so that we can retain our connection to the Armada Query API when using the deferrable Armada Airflow Operator. Airflow requires any state @@ -10,73 +14,55 @@ class GrpcChannelArgs: + __version__: ClassVar[int] = 1 + def __init__( self, target: str, options: Optional[Sequence[Tuple[str, Any]]] = None, compression: Optional[grpc.Compression] = None, auth: Optional[grpc.AuthMetadataPlugin] = None, - auth_details: Optional[Dict[str, Any]] = None, ): self.target = target self.options = options self.compression = compression - if auth: - self.auth = auth - elif auth_details: - classpath, kwargs = auth_details - module_path, class_name = classpath.rsplit( - ".", 1 - ) # Split the classpath to module and class name - module = importlib.import_module( - module_path - ) # Dynamically import the module - cls = getattr(module, class_name) # Get the class from the module - self.auth = cls( - **kwargs - ) # Instantiate the class with the deserialized kwargs - else: - self.auth = None + self.auth = auth def serialize(self) -> Dict[str, Any]: - auth_details = self.auth.serialize() if self.auth else None return { "target": self.target, "options": self.options, "compression": self.compression, - "auth_details": auth_details, + "auth": self.auth, } - def channel(self) -> grpc.Channel: - if self.auth is None: - return grpc.insecure_channel( - target=self.target, options=self.options, compression=self.compression - ) + @staticmethod + def deserialize(data: dict[str, Any], version: int) -> GrpcChannelArgs: + if version > GrpcChannelArgs.__version__: + raise TypeError("serialized version > class version") + return GrpcChannelArgs(**data) - return grpc.secure_channel( - target=self.target, - options=self.options, - compression=self.compression, - credentials=grpc.composite_channel_credentials( - grpc.ssl_channel_credentials(), - grpc.metadata_call_credentials(self.auth), - ), + def __eq__(self, value: object) -> bool: + if type(value) is not GrpcChannelArgs: + return False + return ( + self.target == value.target + and self.options == value.options + and self.compression == value.compression + and self.auth == value.auth ) - def aio_channel(self) -> grpc.aio.Channel: - if self.auth is None: - return grpc.aio.insecure_channel( - target=self.target, - options=self.options, - compression=self.compression, 
- ) - return grpc.aio.secure_channel( - target=self.target, - options=self.options, - compression=self.compression, - credentials=grpc.composite_channel_credentials( - grpc.ssl_channel_credentials(), - grpc.metadata_call_credentials(self.auth), - ), - ) +@dataclass(frozen=True) +class RunningJobContext: + armada_queue: str + job_id: str + job_set_id: str + submit_time: DateTime + cluster: Optional[str] = None + last_log_time: Optional[DateTime] = None + job_state: str = JobState.UNKNOWN.name + + @property + def state(self) -> JobState: + return JobState[self.job_state] diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 7e365417ed3..3f06b99252b 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -17,125 +17,43 @@ # under the License. from __future__ import annotations -import asyncio +import dataclasses import datetime -import functools import os -import threading import time -from dataclasses import dataclass -from functools import cached_property -from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Tuple import jinja2 from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.models import BaseOperator, BaseOperatorLink, XCom +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.serialization.serde import deserialize from airflow.utils.context import Context from airflow.utils.log.logging_mixin import LoggingMixin from armada.auth import TokenRetriever from armada.log_manager import KubernetesPodLogManager from armada.model import GrpcChannelArgs -from armada_client.armada.job_pb2 import JobRunDetails from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.client import ArmadaClient from armada_client.typings import JobState from google.protobuf.json_format import MessageToDict, ParseDict from pendulum import DateTime +from ..hooks import ArmadaHook +from ..model import RunningJobContext +from ..triggers import ArmadaPollJobTrigger +from ..utils import log_exceptions -def log_exceptions(method): - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - try: - return method(self, *args, **kwargs) - except Exception as e: - if hasattr(self, "log") and hasattr(self.log, "error"): - self.log.error(f"Exception in {method.__name__}: {e}") - raise - - return wrapper - - -@dataclass(frozen=False) -class _RunningJobContext: - armada_queue: str - job_set_id: str - job_id: str - state: JobState = JobState.UNKNOWN - start_time: DateTime = DateTime.utcnow() - cluster: Optional[str] = None - last_log_time: Optional[DateTime] = None - - def serialize(self) -> tuple[str, Dict[str, Any]]: - return ( - "armada.operators.armada._RunningJobContext", - { - "armada_queue": self.armada_queue, - "job_set_id": self.job_set_id, - "job_id": self.job_id, - "state": self.state.value, - "start_time": self.start_time, - "cluster": self.cluster, - "last_log_time": self.last_log_time, - }, - ) - - def from_payload(payload: Dict[str, Any]) -> _RunningJobContext: - return _RunningJobContext( - armada_queue=payload["armada_queue"], - job_set_id=payload["job_set_id"], - job_id=payload["job_id"], - state=JobState(payload["state"]), - start_time=payload["start_time"], - cluster=payload["cluster"], - 
last_log_time=payload["last_log_time"], - ) - - -class _ArmadaPollJobTrigger(BaseTrigger): - def __init__(self, moment: datetime.timedelta, context: _RunningJobContext) -> None: - super().__init__() - self.moment = moment - self.context = context - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "armada.operators.armada._ArmadaPollJobTrigger", - {"moment": self.moment, "context": self.context.serialize()}, - ) - - def __eq__(self, value: object) -> bool: - if not isinstance(value, _ArmadaPollJobTrigger): - return False - return self.moment == value.moment and self.context == value.context - async def run(self) -> AsyncIterator[TriggerEvent]: - while self.moment > DateTime.utcnow(): - await asyncio.sleep(1) - yield TriggerEvent(self.context) +class LookoutLink(BaseOperatorLink): + name = "Lookout" + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): + task_state = XCom.get_value(ti_key=ti_key) + if not task_state: + return "" -class _ArmadaClientFactory: - CLIENTS_LOCK = threading.Lock() - CLIENTS: Dict[str, ArmadaClient] = {} - - @staticmethod - def client_for(args: GrpcChannelArgs) -> ArmadaClient: - """ - Armada clients, maintain GRPC connection to Armada API. - We cache them per channel args config in class level cache. - - Access to this method can be from multiple-threads. - """ - channel_args_key = str(args.serialize()) - with _ArmadaClientFactory.CLIENTS_LOCK: - if channel_args_key not in _ArmadaClientFactory.CLIENTS: - _ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient( - channel=args.channel() - ) - return _ArmadaClientFactory.CLIENTS[channel_args_key] + return task_state.get("armada_lookout_url", "") class ArmadaOperator(BaseOperator, LoggingMixin): @@ -146,7 +64,10 @@ class ArmadaOperator(BaseOperator, LoggingMixin): and handles job cancellation if the Airflow task is killed. """ + operator_extra_links = (LookoutLink(),) + template_fields: Sequence[str] = ("job_request", "job_set_prefix") + template_fields_renderers: Dict[str, str] = {"job_request": "py"} """ Initializes a new ArmadaOperator. @@ -158,7 +79,8 @@ class ArmadaOperator(BaseOperator, LoggingMixin): :param armada_queue: The name of the Armada queue to which the job will be submitted. :type armada_queue: str :param job_request: The job to be submitted to Armada. -:type job_request: JobSubmitRequestItem +:type job_request: JobSubmitRequestItem | \ +Callable[[Context, jinja2.Environment], JobSubmitRequestItem] :param job_set_prefix: A string to prepend to the jobSet name. :type job_set_prefix: Optional[str] :param lookout_url_template: Template for creating lookout links. If not specified @@ -177,6 +99,8 @@ class ArmadaOperator(BaseOperator, LoggingMixin): :param job_acknowledgement_timeout: The timeout in seconds to wait for a job to be acknowledged by Armada. :type job_acknowledgement_timeout: int +:param dry_run: Run Operator in dry-run mode - render Armada request and terminate. +:type dry_run: bool :param kwargs: Additional keyword arguments to pass to the BaseOperator. 
""" @@ -185,7 +109,10 @@ def __init__( name: str, channel_args: GrpcChannelArgs, armada_queue: str, - job_request: JobSubmitRequestItem, + job_request: ( + JobSubmitRequestItem + | Callable[[Context, jinja2.Environment], JobSubmitRequestItem] + ), job_set_prefix: Optional[str] = "", lookout_url_template: Optional[str] = None, poll_interval: int = 30, @@ -195,6 +122,7 @@ def __init__( "operators", "default_deferrable", fallback=True ), job_acknowledgement_timeout: int = 5 * 60, + dry_run: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -210,6 +138,7 @@ def __init__( self.k8s_token_retriever = k8s_token_retriever self.deferrable = deferrable self.job_acknowledgement_timeout = job_acknowledgement_timeout + self.dry_run = dry_run self.job_context = None if self.container_logs and self.k8s_token_retriever is None: @@ -226,32 +155,31 @@ def execute(self, context) -> None: :param context: The execution context provided by Airflow. :type context: Context """ - # We take the job_set_id from Airflow's run_id. This means that all jobs in the - # dag will be in the same jobset. + # We take the job_set_id from Airflow's run_id. + # So all jobs in the dag will be in the same jobset. self.job_set_id = f"{self.job_set_prefix}{context['run_id']}" self._annotate_job_request(context, self.job_request) - # Submit job or reattach to previously submitted job. We always do this - # synchronously. - job_id = self._reattach_or_submit_job( - context, self.armada_queue, self.job_set_id, self.job_request - ) + if self.dry_run: + self.log.info( + f"Running in dry_run mode. job_set_id: {self.job_set_id} \n" + f"{self.job_request}" + ) + return - # Wait until finished - self.job_context = _RunningJobContext( - self.armada_queue, self.job_set_id, job_id, start_time=DateTime.utcnow() + # Submit job or reattach to previously submitted job. + # Always do this synchronously. + self.job_context = self._reattach_or_submit_job( + context, self.job_set_id, self.job_request ) - if self.deferrable: - self._deffered_yield(self.job_context) - else: - self._poll_for_termination(self.job_context) + self._poll_for_termination() - @cached_property - def client(self) -> ArmadaClient: - return _ArmadaClientFactory.client_for(self.channel_args) + @property + def hook(self) -> ArmadaHook: + return ArmadaHook(self.channel_args) - @cached_property + @property def pod_manager(self) -> KubernetesPodLogManager: return KubernetesPodLogManager(token_retriever=self.k8s_token_retriever) @@ -270,120 +198,107 @@ def render_template_fields( :param context: Airflow Context dict wi1th values to apply on content :param jinja_env: jinja’s environment to use for rendering. 
""" + if callable(self.job_request): + if not jinja_env: + jinja_env = self.get_template_env() + self.job_request = self.job_request(context, jinja_env) + self.job_request = MessageToDict( self.job_request, preserving_proto_field_name=True ) super().render_template_fields(context, jinja_env) self.job_request = ParseDict(self.job_request, JobSubmitRequestItem()) - def _cancel_job(self, job_context) -> None: - try: - result = self.client.cancel_jobs( - queue=job_context.armada_queue, - job_set_id=job_context.job_set_id, - job_id=job_context.job_id, - ) - if len(list(result.cancelled_ids)) > 0: - self.log.info(f"Cancelled job with id {result.cancelled_ids}") - else: - self.log.warning(f"Failed to cancel job with id {job_context.job_id}") - except Exception as e: - self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}") - def on_kill(self) -> None: if self.job_context is not None: self.log.info( f"on_kill called, " - "cancelling job with id {self.job_context.job_id} in queue " + f"cancelling job with id {self.job_context.job_id} in queue " f"{self.job_context.armada_queue}" ) - self._cancel_job(self.job_context) + self.hook.cancel_job(self.job_context) + self.job_context = None - def _trigger_tracking_message(self, job_id: str): + def lookout_url(self, job_id): if self.lookout_url_template: - return ( - f"Job details available at " - f'{self.lookout_url_template.replace("", job_id)}' - ) + return self.lookout_url_template.replace("", job_id) + return None + + def _trigger_tracking_message(self, job_id): + url = self.lookout_url(job_id) + if url: + return f"Job details available at {url}" return "" - def _deffered_yield(self, context: _RunningJobContext): - self.defer( - timeout=self.execution_timeout, - trigger=_ArmadaPollJobTrigger( - DateTime.utcnow() + datetime.timedelta(seconds=self.poll_interval), - context, - ), - method_name="_deffered_poll_for_termination", - ) + def _yield(self): + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=ArmadaPollJobTrigger( + DateTime.utcnow() + datetime.timedelta(seconds=self.poll_interval), + self.job_context, + self.channel_args, + ), + method_name="_trigger_reentry", + ) + else: + time.sleep(self.poll_interval) - @log_exceptions - def _deffered_poll_for_termination( + def _trigger_reentry( self, context: Context, event: Tuple[str, Dict[str, Any]] ) -> None: - job_run_context = _RunningJobContext.from_payload(event[1]) - while job_run_context.state.is_active(): - job_run_context = self._check_job_status_and_fetch_logs(job_run_context) - if job_run_context.state.is_active(): - self._deffered_yield(job_run_context) - - self._running_job_terminated(job_run_context) + self.job_context = deserialize(event) + self._poll_for_termination() def _reattach_or_submit_job( self, context: Context, - queue: str, job_set_id: str, job_request: JobSubmitRequestItem, - ) -> str: + ) -> RunningJobContext: + # Try to re-initialize job_context from xcom if it exist. ti = context["ti"] - existing_id = ti.xcom_pull( + existing_run = ti.xcom_pull( dag_id=ti.dag_id, task_ids=ti.task_id, key=f"{ti.try_number}" ) - if existing_id is not None: + if existing_run is not None: self.log.info( - f"Attached to existing job with id {existing_id['armada_job_id']}." - f" {self._trigger_tracking_message(existing_id['armada_job_id'])}" + f"Attached to existing job with id {existing_run['armada_job_id']}." 
+ f" {self._trigger_tracking_message(existing_run['armada_job_id'])}" ) - return existing_id["armada_job_id"] - - job_id = self._submit_job(queue, job_set_id, job_request) - self.log.info( - f"Submitted job with id {job_id}. {self._trigger_tracking_message(job_id)}" - ) - ti.xcom_push(key=f"{ti.try_number}", value={"armada_job_id": job_id}) - return job_id - - def _submit_job( - self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem - ) -> str: - resp = self.client.submit_jobs(queue, job_set_id, [job_request]) - num_responses = len(resp.job_response_items) - - # We submitted exactly one job to armada, so we expect a single response - if num_responses != 1: - raise AirflowException( - f"No valid received from Armada (expected 1 job to be created " - f"but got {num_responses}" + return RunningJobContext( + armada_queue=existing_run["armada_queue"], + job_id=existing_run["armada_job_id"], + job_set_id=existing_run["armada_job_set_id"], + submit_time=DateTime.utcnow(), ) - job = resp.job_response_items[0] - - # Throw if armada told us we had submitted something bad - if job.error: - raise AirflowException(f"Error submitting job to Armada: {job.error}") - return job.job_id + # We haven't got a running job, submit a new one and persist state to xcom. + ctx = self.hook.submit_job(self.armada_queue, job_set_id, job_request) + tracking_msg = self._trigger_tracking_message(ctx.job_id) + self.log.info(f"Submitted job with id {ctx.job_id}. {tracking_msg}") + + ti.xcom_push( + key=f"{ti.try_number}", + value={ + "armada_queue": ctx.armada_queue, + "armada_job_id": ctx.job_id, + "armada_job_set_id": ctx.job_set_id, + "armada_lookout_url": self.lookout_url(ctx.job_id), + }, + ) + return ctx - def _poll_for_termination(self, context: _RunningJobContext) -> None: - while context.state.is_active(): - context = self._check_job_status_and_fetch_logs(context) - if context.state.is_active(): - time.sleep(self.poll_interval) + def _poll_for_termination(self) -> None: + while self.job_context.state.is_active(): + self._check_job_status_and_fetch_logs() + if self.job_context.state.is_active(): + self._yield() - self._running_job_terminated(context) + self._running_job_terminated(self.job_context) - def _running_job_terminated(self, context: _RunningJobContext): + def _running_job_terminated(self, context: RunningJobContext): self.log.info( f"job {context.job_id} terminated with state: {context.state.name}" ) @@ -393,57 +308,43 @@ def _running_job_terminated(self, context: _RunningJobContext): f"Final status was {context.state.name}" ) - @log_exceptions - def _check_job_status_and_fetch_logs( - self, context: _RunningJobContext - ) -> _RunningJobContext: - response = self.client.get_job_status([context.job_id]) - state = JobState(response.job_states[context.job_id]) - if state != context.state: - self.log.info( - f"job {context.job_id} is in state: {state.name}. 
" - f"{self._trigger_tracking_message(context.job_id)}" - ) - context.state = state - - if context.state == JobState.UNKNOWN: + def _not_acknowledged_within_timeout(self) -> bool: + if self.job_context.state == JobState.UNKNOWN: if ( - DateTime.utcnow().diff(context.start_time).in_seconds() + DateTime.utcnow().diff(self.job_context.submit_time).in_seconds() > self.job_acknowledgement_timeout ): - self.log.info( - f"Job {context.job_id} not acknowledged by the Armada within " - f"timeout ({self.job_acknowledgement_timeout}), terminating" - ) - self._cancel_job(context) - context.state = JobState.CANCELLED - return context + return True + return False - if self.container_logs and not context.cluster: - if context.state == JobState.RUNNING or context.state.is_terminal(): - run_details = self._get_latest_job_run_details(context.job_id) - context.cluster = run_details.cluster + @log_exceptions + def _check_job_status_and_fetch_logs(self) -> None: + self.job_context = self.hook.refresh_context( + self.job_context, self._trigger_tracking_message(self.job_context.job_id) + ) - if context.cluster: + if self._not_acknowledged_within_timeout(): + self.log.info( + f"Job {self.job_context.job_id} not acknowledged by the Armada within " + f"timeout ({self.job_acknowledgement_timeout}), terminating" + ) + self.job_context = self.hook.cancel_job(self.job_context) + return + + if self.job_context.cluster and self.container_logs: try: - context.last_log_time = self.pod_manager.fetch_container_logs( - k8s_context=context.cluster, + last_log_time = self.pod_manager.fetch_container_logs( + k8s_context=self.job_context.cluster, namespace=self.job_request.namespace, - pod=f"armada-{context.job_id}-0", + pod=f"armada-{self.job_context.job_id}-0", container=self.container_logs, - since_time=context.last_log_time, + since_time=self.job_context.last_log_time, + ) + self.job_context = dataclasses.replace( + self.job_context, last_log_time=last_log_time ) except Exception as e: self.log.warning(f"Error fetching logs {e}") - return context - - def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: - job_details = self.client.get_job_details([job_id]).job_details[job_id] - if job_details and job_details.latest_run_id: - for run in job_details.job_runs: - if run.run_id == job_details.latest_run_id: - return run - return None @staticmethod def _annotate_job_request(context, request: JobSubmitRequestItem): diff --git a/third_party/airflow/armada/plugin.py b/third_party/airflow/armada/plugin.py new file mode 100644 index 00000000000..c7694566914 --- /dev/null +++ b/third_party/airflow/armada/plugin.py @@ -0,0 +1,10 @@ +from airflow.plugins_manager import AirflowPlugin + +from .armada.operators.armada import LookoutLink + + +class AirflowExtraLinkPlugin(AirflowPlugin): + name = "extra_link_plugin" + operator_extra_links = [ + LookoutLink(), + ] diff --git a/third_party/airflow/armada/triggers.py b/third_party/airflow/armada/triggers.py new file mode 100644 index 00000000000..2ea44e16c0c --- /dev/null +++ b/third_party/airflow/armada/triggers.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import asyncio +from datetime import timedelta +from typing import Any, AsyncIterator, ClassVar, Dict + +from airflow.exceptions import AirflowException +from airflow.models.taskinstance import TaskInstance +from airflow.serialization.serde import deserialize, serialize +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.session import provide_session +from airflow.utils.state 
import TaskInstanceState +from pendulum import DateTime +from sqlalchemy.orm.session import Session + +from .hooks import ArmadaHook +from .model import GrpcChannelArgs, RunningJobContext +from .utils import log_exceptions + + +class ArmadaPollJobTrigger(BaseTrigger): + __version__: ClassVar[int] = 1 + + @log_exceptions + def __init__( + self, + moment: timedelta, + context: RunningJobContext | tuple[str, Dict[str, Any]], + channel_args: GrpcChannelArgs | tuple[str, Dict[str, Any]], + ) -> None: + super().__init__() + + self.moment = moment + if type(context) is RunningJobContext: + self.context = context + else: + self.context = deserialize(context) + + if type(channel_args) is GrpcChannelArgs: + self.channel_args = channel_args + else: + self.channel_args = deserialize(channel_args) + + @log_exceptions + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "armada.triggers.ArmadaPollJobTrigger", + { + "moment": self.moment, + "context": serialize(self.context), + "channel_args": serialize(self.channel_args), + }, + ) + + @log_exceptions + @provide_session + def get_task_instance(self, session: Session) -> TaskInstance: + """ + Get the task instance for the current task. + :param session: Sqlalchemy session + """ + query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) + task_instance = query.one_or_none() + if task_instance is None: + raise AirflowException( + "TaskInstance with dag_id: %s,task_id: %s, " + "run_id: %s and map_index: %s is not found", + self.task_instance.dag_id, + self.task_instance.task_id, + self.task_instance.run_id, + self.task_instance.map_index, + ) + return task_instance + + def should_cancel_job(self) -> bool: + """ + We only want to cancel jobs when task is being marked Failed/Succeeded. + """ + # Database query is needed to get the latest state of the task instance. 
+ task_instance = self.get_task_instance() # type: ignore[call-arg] + return task_instance.state != TaskInstanceState.DEFERRED + + def __eq__(self, value: object) -> bool: + if not isinstance(value, ArmadaPollJobTrigger): + return False + return ( + self.moment == value.moment + and self.context == value.context + and self.channel_args == value.channel_args + ) + + @property + def hook(self) -> ArmadaHook: + return ArmadaHook(self.channel_args) + + @log_exceptions + async def run(self) -> AsyncIterator[TriggerEvent]: + try: + while self.moment > DateTime.utcnow(): + await asyncio.sleep(1) + yield TriggerEvent(serialize(self.context)) + except asyncio.CancelledError: + if self.should_cancel_job(): + self.hook.cancel_job(self.context) + raise diff --git a/third_party/airflow/armada/utils.py b/third_party/airflow/armada/utils.py new file mode 100644 index 00000000000..e700a1bbc5e --- /dev/null +++ b/third_party/airflow/armada/utils.py @@ -0,0 +1,14 @@ +import functools + + +def log_exceptions(method): + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + try: + return method(self, *args, **kwargs) + except Exception as e: + if hasattr(self, "log") and hasattr(self.log, "error"): + self.log.error(f"Exception in {method.__name__}: {e}") + raise + + return wrapper diff --git a/third_party/airflow/docs/source/conf.py b/third_party/airflow/docs/source/conf.py index 10d3949aee8..a7e2f5a75bb 100644 --- a/third_party/airflow/docs/source/conf.py +++ b/third_party/airflow/docs/source/conf.py @@ -13,14 +13,14 @@ import os import sys -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # -- Project information ----------------------------------------------------- -project = 'python-armadaairflowoperator' -copyright = '2022 Armada Project' -author = 'armada@armadaproject.io' +project = "python-armadaairflowoperator" +copyright = "2022 Armada Project" +author = "armada@armadaproject.io" # -- General configuration --------------------------------------------------- @@ -28,12 +28,12 @@ # Jekyll is the style of markdown used by github pages; using # sphinx_jekyll_builder here allows us to generate docs as # markdown files. -extensions = ['sphinx.ext.autodoc', 'sphinx_jekyll_builder'] +extensions = ["sphinx.ext.autodoc", "sphinx_jekyll_builder"] # This setting puts information about typing in the description section instead # of in the function signature directly. This makes rendered content look much # better in our gh-pages template that renders the generated markdown. -autodoc_typehints = 'description' +autodoc_typehints = "description" # Add any paths that contain templates here, relative to this directory. templates_path = [] @@ -49,7 +49,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. 
They are copied after the builtin static files, diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index 8f8fb538a57..3c8471bde44 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "armada_airflow" -version = "1.0.1" +version = "1.0.2" description = "Armada Airflow Operator" readme='README.md' authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] @@ -31,6 +31,9 @@ test = ["pytest==7.3.1", "coverage==7.3.2", "pytest-asyncio==0.21.1", # note(JayF): sphinx-jekyll-builder was broken by sphinx-markdown-builder 0.6 -- so pin to 0.5.5 docs = ["sphinx==7.1.2", "sphinx-jekyll-builder==0.3.0", "sphinx-toolbox==3.2.0b1", "sphinx-markdown-builder==0.5.5"] +[project.entry-points.apache_airflow_provider] +provider_info = "armada.__init__:get_provider_info" + [project.urls] repository='https://github.com/armadaproject/armada' @@ -39,7 +42,7 @@ include = ["armada_airflow*"] [tool.black] line-length = 88 -target-version = ['py310'] +target-version = ['py38', 'py39', 'py310'] include = ''' /( armada diff --git a/third_party/airflow/test/integration/test_airflow_operator_logic.py b/third_party/airflow/test/integration/test_airflow_operator_logic.py index 4bc3c43418e..594d3d5eaec 100644 --- a/third_party/airflow/test/integration/test_airflow_operator_logic.py +++ b/third_party/airflow/test/integration/test_airflow_operator_logic.py @@ -85,9 +85,7 @@ def sleep_pod(image: str): ] -def test_success_job( - client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker -): +def test_success_job(client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs): job_set_name = f"test-{uuid.uuid1()}" job = client.submit_jobs( queue=DEFAULT_QUEUE, @@ -96,10 +94,11 @@ def test_success_job( ) job_id = job.job_response_items[0].job_id - mocker.patch( - "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", - return_value=job_id, - ) + context["ti"].xcom_pull.return_value = { + "armada_queue": DEFAULT_QUEUE, + "armada_job_id": job_id, + "armada_job_set_id": job_set_name, + } operator = ArmadaOperator( task_id=DEFAULT_TASK_ID, @@ -113,13 +112,11 @@ def test_success_job( operator.execute(context) - response = operator.client.get_job_status([job_id]) + response = client.get_job_status([job_id]) assert JobState(response.job_states[job_id]) == JobState.SUCCEEDED -def test_bad_job( - client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker -): +def test_bad_job(client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs): job_set_name = f"test-{uuid.uuid1()}" job = client.submit_jobs( queue=DEFAULT_QUEUE, @@ -128,10 +125,11 @@ def test_bad_job( ) job_id = job.job_response_items[0].job_id - mocker.patch( - "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", - return_value=job_id, - ) + context["ti"].xcom_pull.return_value = { + "armada_queue": DEFAULT_QUEUE, + "armada_job_id": job_id, + "armada_job_set_id": job_set_name, + } operator = ArmadaOperator( task_id=DEFAULT_TASK_ID, @@ -149,7 +147,7 @@ def test_bad_job( "Operator did not raise AirflowException on job failure as expected" ) except AirflowException: # Expected - response = operator.client.get_job_status([job_id]) + response = client.get_job_status([job_id]) assert JobState(response.job_states[job_id]) == JobState.FAILED except Exception as e: pytest.fail( @@ -159,7 +157,7 @@ def test_bad_job( def success_job( - task_number: int, context: Any, 
channel_args: GrpcChannelArgs + task_number: int, context: Any, channel_args: GrpcChannelArgs, client: ArmadaClient ) -> JobState: operator = ArmadaOperator( task_id=f"{DEFAULT_TASK_ID}_{task_number}", @@ -173,7 +171,7 @@ def success_job( operator.execute(context) - response = operator.client.get_job_status([operator.job_id]) + response = client.get_job_status([operator.job_id]) return JobState(response.job_states[operator.job_id]) @@ -182,7 +180,9 @@ def test_parallel_execution( client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker ): threads = [] - success_job(task_number=0, context=context, channel_args=channel_args) + success_job( + task_number=0, context=context, channel_args=channel_args, client=client + ) for task_number in range(5): t = threading.Thread( target=success_job, args=[task_number, context, channel_args] @@ -199,7 +199,9 @@ def test_parallel_execution_large( client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker ): threads = [] - success_job(task_number=0, context=context, channel_args=channel_args) + success_job( + task_number=0, context=context, channel_args=channel_args, client=client + ) for task_number in range(80): t = threading.Thread( target=success_job, args=[task_number, context, channel_args] @@ -216,7 +218,9 @@ def test_parallel_execution_huge( client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker ): threads = [] - success_job(task_number=0, context=context, channel_args=channel_args) + success_job( + task_number=0, context=context, channel_args=channel_args, client=client + ) for task_number in range(500): t = threading.Thread( target=success_job, args=[task_number, context, channel_args] diff --git a/third_party/airflow/test/operators/test_armada.py b/third_party/airflow/test/operators/test_armada.py deleted file mode 100644 index 85129000ad1..00000000000 --- a/third_party/airflow/test/operators/test_armada.py +++ /dev/null @@ -1,324 +0,0 @@ -import unittest -from datetime import timedelta -from math import ceil -from unittest.mock import MagicMock, PropertyMock, patch - -from airflow.exceptions import AirflowException -from armada.model import GrpcChannelArgs -from armada.operators.armada import ( - ArmadaOperator, - _ArmadaPollJobTrigger, - _RunningJobContext, -) -from armada_client.armada import job_pb2, submit_pb2 -from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) -from armada_client.typings import JobState -from pendulum import UTC, DateTime - -DEFAULT_CURRENT_TIME = DateTime(2024, 8, 7, tzinfo=UTC) -DEFAULT_JOB_ID = "test_job" -DEFAULT_TASK_ID = "test_task_1" -DEFAULT_DAG_ID = "test_dag_1" -DEFAULT_RUN_ID = "test_run_1" -DEFAULT_QUEUE = "test_queue_1" -DEFAULT_POLLING_INTERVAL = 30 -DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60 - - -class TestArmadaOperator(unittest.TestCase): - def setUp(self): - # Set up a mock context - mock_ti = MagicMock() - mock_ti.task_id = DEFAULT_TASK_ID - mock_dag = MagicMock() - mock_dag.dag_id = DEFAULT_DAG_ID - self.context = { - "ti": mock_ti, - "run_id": DEFAULT_RUN_ID, - "dag": mock_dag, - } - - @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_execute(self, mock_client_fn, _): - test_cases = [ - { - "name": "Job Succeeds", - "statuses": [submit_pb2.RUNNING, submit_pb2.SUCCEEDED], - "success": 
True, - }, - { - "name": "Job Failed", - "statuses": [submit_pb2.RUNNING, submit_pb2.FAILED], - "success": False, - }, - { - "name": "Job cancelled", - "statuses": [submit_pb2.RUNNING, submit_pb2.CANCELLED], - "success": False, - }, - { - "name": "Job preempted", - "statuses": [submit_pb2.RUNNING, submit_pb2.PREEMPTED], - "success": False, - }, - { - "name": "Job Succeeds but takes a lot of transitions", - "statuses": [ - submit_pb2.SUBMITTED, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.SUCCEEDED, - ], - "success": True, - }, - ] - - for test_case in test_cases: - with self.subTest(test_case=test_case["name"]): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[ - submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID) - ] - ) - - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in test_case["statuses"] - ] - - mock_client_fn.return_value = mock_client - self.context["ti"].xcom_pull.return_value = None - - try: - operator.execute(self.context) - self.assertTrue(test_case["success"]) - except AirflowException: - self.assertFalse(test_case["success"]) - return - - self.assertEqual(mock_client.submit_jobs.call_count, 1) - self.assertEqual( - mock_client.get_job_status.call_count, len(test_case["statuses"]) - ) - - @patch("time.sleep", return_value=None) - @patch( - "armada.operators.armada.ArmadaOperator._cancel_job", new_callable=PropertyMock - ) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_unacknowledged_results_in_on_kill(self, mock_client_fn, mock_on_kill, _): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=False, - job_acknowledgement_timeout=-1, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] - ) - mock_client_fn.return_value = mock_client - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [submit_pb2.UNKNOWN, submit_pb2.UNKNOWN] - ] - - self.context["ti"].xcom_pull.return_value = None - with self.assertRaises(AirflowException): - operator.execute(self.context) - self.assertEqual(mock_on_kill.call_count, 1) - - """We call on_kill by triggering the job unacknowledged timeout""" - - @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_on_kill_cancels_job(self, mock_client_fn, _): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=False, - job_acknowledgement_timeout=-1, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - 
job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] - ) - mock_client_fn.return_value = mock_client - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [ - submit_pb2.UNKNOWN - for _ in range( - 1 - + ceil( - DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL - ) - ) - ] - ] - - self.context["ti"].xcom_pull.return_value = None - with self.assertRaises(AirflowException): - operator.execute(self.context) - self.assertEqual(mock_client.cancel_jobs.call_count, 1) - - @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_job_reattaches(self, mock_client_fn, _): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=False, - job_acknowledgement_timeout=10, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [ - submit_pb2.SUCCEEDED - for _ in range( - 1 - + ceil( - DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL - ) - ) - ] - ] - mock_client_fn.return_value = mock_client - self.context["ti"].xcom_pull.return_value = {"armada_job_id": DEFAULT_JOB_ID} - - operator.execute(self.context) - self.assertEqual(mock_client.submit_jobs.call_count, 0) - - -class TestArmadaOperatorDeferrable(unittest.IsolatedAsyncioTestCase): - def setUp(self): - # Set up a mock context - mock_ti = MagicMock() - mock_ti.task_id = DEFAULT_TASK_ID - mock_dag = MagicMock() - mock_dag.dag_id = DEFAULT_DAG_ID - self.context = { - "ti": mock_ti, - "run_id": DEFAULT_RUN_ID, - "dag": mock_dag, - } - - @patch("pendulum.DateTime.utcnow") - @patch("armada.operators.armada.ArmadaOperator.defer") - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_execute_deferred(self, mock_client_fn, mock_defer_fn, mock_datetime_now): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=True, - ) - - mock_datetime_now.return_value = DEFAULT_CURRENT_TIME - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] - ) - mock_client_fn.return_value = mock_client - self.context["ti"].xcom_pull.return_value = None - - operator.execute(self.context) - self.assertEqual(mock_client.submit_jobs.call_count, 1) - mock_defer_fn.assert_called_with( - timeout=operator.execution_timeout, - trigger=_ArmadaPollJobTrigger( - moment=DEFAULT_CURRENT_TIME + timedelta(seconds=operator.poll_interval), - context=_RunningJobContext( - armada_queue=DEFAULT_QUEUE, - job_set_id=operator.job_set_id, - job_id=DEFAULT_JOB_ID, - state=JobState.UNKNOWN, - start_time=DEFAULT_CURRENT_TIME, - cluster=None, - last_log_time=None, - ), - ), - method_name="_deffered_poll_for_termination", - ) - - def test_templating(self): - """Tests templating for both the job_prefix and the pod spec""" - prefix = "{{ run_id }}" - pod_arg = "{{ run_id }}" - - pod = core_v1.PodSpec( - containers=[ - core_v1.Container( - name="sleep", - image="alpine:3.20.2", - args=[pod_arg], - 
securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - ), - ) - ], - ) - job = JobSubmitRequestItem(priority=1, pod_spec=pod, namespace="armada") - - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=job, - job_set_prefix=prefix, - task_id=DEFAULT_TASK_ID, - deferrable=True, - ) - - operator.render_template_fields(self.context) - - self.assertEqual(operator.job_set_prefix, "test_run_1") - self.assertEqual( - operator.job_request.pod_spec.containers[0].args[0], "test_run_1" - ) diff --git a/third_party/airflow/test/__init__.py b/third_party/airflow/test/unit/__init__.py similarity index 100% rename from third_party/airflow/test/__init__.py rename to third_party/airflow/test/unit/__init__.py diff --git a/third_party/airflow/test/operators/__init__.py b/third_party/airflow/test/unit/operators/__init__.py similarity index 100% rename from third_party/airflow/test/operators/__init__.py rename to third_party/airflow/test/unit/operators/__init__.py diff --git a/third_party/airflow/test/unit/operators/test_armada.py b/third_party/airflow/test/unit/operators/test_armada.py new file mode 100644 index 00000000000..d2aab33cce4 --- /dev/null +++ b/third_party/airflow/test/unit/operators/test_armada.py @@ -0,0 +1,262 @@ +import dataclasses +from datetime import timedelta +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +from airflow.exceptions import AirflowException, TaskDeferred +from armada.model import GrpcChannelArgs, RunningJobContext +from armada.operators.armada import ArmadaOperator +from armada.triggers import ArmadaPollJobTrigger +from armada_client.armada.submit_pb2 import JobSubmitRequestItem +from armada_client.typings import JobState +from pendulum import UTC, DateTime + +DEFAULT_CURRENT_TIME = DateTime(2024, 8, 7, tzinfo=UTC) +DEFAULT_JOB_ID = "test_job" +DEFAULT_TASK_ID = "test_task_1" +DEFAULT_JOB_SET = "prefix-test_run_1" +DEFAULT_QUEUE = "test_queue_1" +DEFAULT_CLUSTER = "cluster-1" + + +def default_hook() -> MagicMock: + mock = MagicMock() + job_context = running_job_context() + mock.submit_job.return_value = job_context + mock.refresh_context.return_value = dataclasses.replace( + job_context, job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER + ) + mock.cancel_job.return_value = dataclasses.replace( + job_context, job_state=JobState.CANCELLED.name + ) + + return mock + + +@pytest.fixture(scope="function", autouse=True) +def mock_operator_dependencies(): + # We no-op time.sleep in tests. 
+ with patch("time.sleep", return_value=None) as sleep, patch( + "armada.log_manager.KubernetesPodLogManager.fetch_container_logs" + ) as logs, patch( + "armada.operators.armada.ArmadaOperator.hook", new_callable=default_hook + ) as hook: + yield sleep, logs, hook + + +@pytest.fixture +def context(): + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_ti.try_number = 0 + mock_ti.xcom_pull.return_value = None + + mock_dag = MagicMock() + mock_dag.dag_id = "test_dag_1" + + context = {"ti": mock_ti, "run_id": "test_run_1", "dag": mock_dag} + + return context + + +def operator( + job_request: JobSubmitRequestItem, + deferrable: bool = False, + job_acknowledgement_timeout_s: int = 30, + container_logs: Optional[str] = None, +) -> ArmadaOperator: + operator = ArmadaOperator( + armada_queue=DEFAULT_QUEUE, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + container_logs=container_logs, + deferrable=deferrable, + job_acknowledgement_timeout=job_acknowledgement_timeout_s, + job_request=job_request, + job_set_prefix="prefix-", + lookout_url_template="http://lookout.armadaproject.io/jobs?job_id=", + name="test", + task_id=DEFAULT_TASK_ID, + ) + + return operator + + +def running_job_context( + cluster: str = None, + submit_time: DateTime = DateTime.now(), + job_state: str = JobState.UNKNOWN.name, +) -> RunningJobContext: + return RunningJobContext( + DEFAULT_QUEUE, + DEFAULT_JOB_ID, + DEFAULT_JOB_SET, + submit_time, + cluster, + job_state=job_state, + ) + + +@pytest.mark.parametrize( + "job_states", + [ + [JobState.RUNNING, JobState.SUCCEEDED], + [ + JobState.QUEUED, + JobState.LEASED, + JobState.QUEUED, + JobState.RUNNING, + JobState.SUCCEEDED, + ], + ], + ids=["success", "success - multiple events"], +) +def test_execute(job_states, context): + op = operator(JobSubmitRequestItem()) + + op.hook.refresh_context.side_effect = [ + running_job_context(cluster="cluster-1", job_state=s.name) for s in job_states + ] + + op.execute(context) + + op.hook.submit_job.assert_called_once_with( + DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request + ) + assert op.hook.refresh_context.call_count == len(job_states) + + # We're not polling for logs + op.pod_manager.fetch_container_logs.assert_not_called() + + +@patch("pendulum.DateTime.utcnow", return_value=DEFAULT_CURRENT_TIME) +def test_execute_in_deferrable(_, context): + op = operator(JobSubmitRequestItem(), deferrable=True) + op.hook.refresh_context.side_effect = [ + running_job_context(cluster="cluster-1", job_state=s.name) + for s in [JobState.QUEUED, JobState.QUEUED] + ] + + with pytest.raises(TaskDeferred) as deferred: + op.execute(context) + + op.hook.submit_job.assert_called_once_with( + DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request + ) + assert deferred.value.timeout == op.execution_timeout + assert deferred.value.trigger == ArmadaPollJobTrigger( + moment=DEFAULT_CURRENT_TIME + timedelta(seconds=op.poll_interval), + context=op.job_context, + channel_args=op.channel_args, + ) + assert deferred.value.method_name == "_trigger_reentry" + + +@pytest.mark.parametrize( + "terminal_state", + [JobState.FAILED, JobState.PREEMPTED, JobState.CANCELLED], + ids=["failed", "preempted", "cancelled"], +) +def test_execute_fail(terminal_state, context): + op = operator(JobSubmitRequestItem()) + + op.hook.refresh_context.side_effect = [ + running_job_context(cluster="cluster-1", job_state=s.name) + for s in [JobState.RUNNING, terminal_state] + ] + + with pytest.raises(AirflowException) as exec_info: + op.execute(context) + + # Error message 
contains terminal state and job id
+    assert DEFAULT_JOB_ID in str(exec_info)
+    assert terminal_state.name in str(exec_info)
+
+    op.hook.submit_job.assert_called_once_with(
+        DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request
+    )
+    assert op.hook.refresh_context.call_count == 2
+
+    # We're not polling for logs
+    op.pod_manager.fetch_container_logs.assert_not_called()
+
+
+def test_on_kill_terminates_running_job():
+    op = operator(JobSubmitRequestItem())
+    job_context = running_job_context()
+    op.job_context = job_context
+
+    op.on_kill()
+    op.on_kill()
+
+    # We ensure we only try to cancel the job once.
+    op.hook.cancel_job.assert_called_once_with(job_context)
+
+
+def test_not_acknowledged_within_timeout_terminates_running_job(context):
+    job_context = running_job_context()
+    op = operator(JobSubmitRequestItem(), job_acknowledgement_timeout_s=-1)
+    op.hook.refresh_context.return_value = job_context
+
+    with pytest.raises(AirflowException) as exec_info:
+        op.execute(context)
+
+    # Error message contains terminal state and job id
+    assert DEFAULT_JOB_ID in str(exec_info)
+    assert JobState.CANCELLED.name in str(exec_info)
+
+    # We also cancel the already submitted job
+    op.hook.cancel_job.assert_called_once_with(job_context)
+
+
+def test_polls_for_logs(context):
+    op = operator(
+        JobSubmitRequestItem(namespace="namespace-1"), container_logs="alpine"
+    )
+    op.execute(context)
+
+    # We polled logs as expected.
+    op.pod_manager.fetch_container_logs.assert_called_once_with(
+        k8s_context="cluster-1",
+        namespace="namespace-1",
+        pod="armada-test_job-0",
+        container="alpine",
+        since_time=None,
+    )
+
+
+def test_publishes_xcom_state(context):
+    op = operator(JobSubmitRequestItem())
+    op.execute(context)
+
+    lookout_url = f"http://lookout.armadaproject.io/jobs?job_id={DEFAULT_JOB_ID}"
+    context["ti"].xcom_push.assert_called_once_with(
+        key="0",
+        value={
+            "armada_job_id": DEFAULT_JOB_ID,
+            "armada_job_set_id": DEFAULT_JOB_SET,
+            "armada_lookout_url": lookout_url,
+            "armada_queue": DEFAULT_QUEUE,
+        },
+    )
+
+
+def test_reattaches_to_running_job(context):
+    op = operator(JobSubmitRequestItem())
+    context["ti"].xcom_pull.return_value = {
+        "armada_job_id": DEFAULT_JOB_ID,
+        "armada_job_set_id": DEFAULT_JOB_SET,
+        "armada_queue": DEFAULT_QUEUE,
+    }
+
+    op.execute(context)
+
+    assert op.job_context == running_job_context(
+        job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER
+    )
+    op.hook.submit_job.assert_not_called()
+
+
+@pytest.mark.skip("TODO")
+def test_templates_job_request_item():
+    pass
diff --git a/third_party/airflow/test/unit/test_hooks.py b/third_party/airflow/test/unit/test_hooks.py
new file mode 100644
index 00000000000..0a2e1ba2e11
--- /dev/null
+++ b/third_party/airflow/test/unit/test_hooks.py
@@ -0,0 +1,16 @@
+import pytest
+
+
+@pytest.mark.skip("TODO")
+def test_submits_job_using_armada_client():
+    pass
+
+
+@pytest.mark.skip("TODO")
+def test_cancels_job_using_armada_client():
+    pass
+
+
+@pytest.mark.skip("TODO")
+def test_updates_job_context():
+    pass
diff --git a/third_party/airflow/test/unit/test_model.py b/third_party/airflow/test/unit/test_model.py
new file mode 100644
index 00000000000..906b7315ad9
--- /dev/null
+++ b/third_party/airflow/test/unit/test_model.py
@@ -0,0 +1,33 @@
+import grpc
+from airflow.serialization.serde import deserialize, serialize
+from armada.model import GrpcChannelArgs, RunningJobContext
+from armada_client.typings import JobState
+from pendulum import DateTime
+
+
+def test_roundtrip_running_job_context():
+    context = RunningJobContext(
+        "queue_123",
+ "job_id_123", + "job_set_id_123", + DateTime.utcnow(), + "cluster-1.armada.localhost", + DateTime.utcnow().add(minutes=-2), + JobState.RUNNING.name, + ) + + result = deserialize(serialize(context)) + assert context == result + assert JobState.RUNNING == result.state + + +def test_roundtrip_grpc_channel_args(): + channel_args = GrpcChannelArgs( + "armada-api.localhost", + [("key-1", 10), ("key-2", "value-2")], + grpc.Compression.NoCompression, + None, + ) + + result = deserialize(serialize(channel_args)) + assert channel_args == result diff --git a/third_party/airflow/test/unit/test_triggers.py b/third_party/airflow/test/unit/test_triggers.py new file mode 100644 index 00000000000..bdd15333caa --- /dev/null +++ b/third_party/airflow/test/unit/test_triggers.py @@ -0,0 +1,16 @@ +import pytest + + +@pytest.mark.skip("TODO") +def test_yields_with_context(): + pass + + +@pytest.mark.skip("TODO") +def test_cancels_running_job_when_task_is_cancelled(): + pass + + +@pytest.mark.skip("TODO") +def test_do_not_cancels_running_job_when_trigger_is_suspended(): + pass diff --git a/third_party/airflow/tox.ini b/third_party/airflow/tox.ini index 09dd8ce15ea..ed457e94d70 100644 --- a/third_party/airflow/tox.ini +++ b/third_party/airflow/tox.ini @@ -13,7 +13,7 @@ allowlist_externals = find xargs commands = - coverage run -m unittest discover + coverage run -m pytest test/unit/ coverage xml # This executes the dag files in examples but really only checks for imports and python errors bash -c "find examples/ -maxdepth 1 -type f -name *.py | xargs python3"
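
Not part of the patch above: a minimal sketch, for reviewers, of how the reworked operator might be wired into a DAG using the new surface introduced in this change (callable `job_request`, `dry_run`, hook-backed polling). The Armada endpoint, queue name, DAG id, image, namespace and resource values below are illustrative assumptions only; `ArmadaOperator`, `GrpcChannelArgs`, `JobSubmitRequestItem` and their parameters come from the diff.

```python
import pendulum
from airflow import DAG
from armada.model import GrpcChannelArgs
from armada.operators.armada import ArmadaOperator
from armada_client.armada.submit_pb2 import JobSubmitRequestItem
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
from armada_client.k8s.io.apimachinery.pkg.api.resource import (
    generated_pb2 as api_resource,
)


def build_job_request(context, jinja_env) -> JobSubmitRequestItem:
    # job_request may now be a callable taking (Context, jinja2.Environment),
    # so the request can be built per task run, after template rendering.
    pod = core_v1.PodSpec(
        containers=[
            core_v1.Container(
                name="sleep",
                image="alpine:3.20.2",  # assumed image
                command=["sleep"],
                args=["10"],
                resources=core_v1.ResourceRequirements(
                    requests={
                        "cpu": api_resource.Quantity(string="100m"),
                        "memory": api_resource.Quantity(string="256Mi"),
                    },
                    limits={
                        "cpu": api_resource.Quantity(string="100m"),
                        "memory": api_resource.Quantity(string="256Mi"),
                    },
                ),
            )
        ],
    )
    return JobSubmitRequestItem(priority=1, pod_spec=pod, namespace="default")


with DAG(
    dag_id="armada_example",  # assumed DAG id
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    ArmadaOperator(
        task_id="sleep_job",
        name="sleep_job",
        channel_args=GrpcChannelArgs(target="localhost:50051"),  # assumed endpoint
        armada_queue="example-queue",  # assumed queue
        job_request=build_job_request,
        job_set_prefix="airflow-",
        deferrable=True,
        # With dry_run=True the operator only renders and logs the request,
        # then returns without submitting anything to Armada.
        dry_run=False,
    )
```

Passing a callable instead of a concrete `JobSubmitRequestItem` defers building the request until render time, so values from the Airflow context can shape the submitted job; combined with `dry_run`, the rendered request can be inspected without touching the cluster.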