diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 227dfeaa1b1..29709313ac9 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -36,19 +36,19 @@ and handles job cancellation if the Airflow task is killed. * **job_request** (*JobSubmitRequestItem*) – - * **job_set_prefix** (*str** | **None*) – + * **job_set_prefix** (*Optional**[**str**]*) – - * **lookout_url_template** (*str** | **None*) – + * **lookout_url_template** (*Optional**[**str**]*) – * **poll_interval** (*int*) – - * **container_logs** (*str** | **None*) – + * **container_logs** (*Optional**[**str**]*) – - * **k8s_token_retriever** (*TokenRetriever** | **None*) – + * **k8s_token_retriever** (*Optional**[**TokenRetriever**]*) – * **deferrable** (*bool*) – @@ -89,19 +89,7 @@ operator needs to be cleaned up, or it will leave ghost processes behind. -#### pod_manager(k8s_context) - -* **Parameters** - - **k8s_context** (*str*) – - - - -* **Return type** - - *PodLogManager* - - +#### _property_ pod_manager(_: KubernetesPodLogManage_ ) #### render_template_fields(context, jinja_env=None) Template all attributes listed in self.template_fields. @@ -147,7 +135,7 @@ Initializes a new ArmadaOperator. * **job_request** (*JobSubmitRequestItem*) – The job to be submitted to Armada. - * **job_set_prefix** (*Optional**[**str**]*) – A string to prepend to the jobSet name + * **job_set_prefix** (*Optional**[**str**]*) – A string to prepend to the jobSet name. * **lookout_url_template** – Template for creating lookout links. If not specified @@ -170,95 +158,9 @@ acknowledged by Armada. :type job_acknowledgement_timeout: int :param kwargs: Additional keyword arguments to pass to the BaseOperator. -## armada.triggers.armada module - - -### _class_ armada.triggers.armada.ArmadaTrigger(job_id, armada_queue, job_set_id, poll_interval, tracking_message, job_acknowledgement_timeout, job_request_namespace, channel_args=None, channel_args_details=None, container_logs=None, k8s_token_retriever=None, k8s_token_retriever_details=None, last_log_time=None) -Bases: `BaseTrigger` - -An Airflow Trigger that can asynchronously manage an Armada job. - - -* **Parameters** - - - * **job_id** (*str*) – - - - * **armada_queue** (*str*) – - - - * **job_set_id** (*str*) – - - - * **poll_interval** (*int*) – - - - * **tracking_message** (*str*) – - - - * **job_acknowledgement_timeout** (*int*) – - - - * **job_request_namespace** (*str*) – - - - * **channel_args** (*GrpcChannelArgs*) – - - - * **channel_args_details** (*Dict**[**str**, **Any**]*) – - - - * **container_logs** (*str** | **None*) – - - - * **k8s_token_retriever** (*TokenRetriever** | **None*) – - - - * **k8s_token_retriever_details** (*Tuple**[**str**, **Dict**[**str**, **Any**]**] **| **None*) – - - - * **last_log_time** (*DateTime** | **None*) – - - - -#### _property_ client(_: ArmadaAsyncIOClien_ ) - -#### pod_manager(k8s_context) - -* **Parameters** - - **k8s_context** (*str*) – - - - -* **Return type** - - *PodLogManagerAsync* - - - -#### _async_ run() -Run the Trigger Asynchronously. This will poll Armada until the Job reaches a -terminal state - - -* **Return type** - - *AsyncIterator*[*TriggerEvent*] - - - -#### serialize() -Serialises the state of this Trigger. -When the Trigger is re-hydrated, these values will be passed to init() as kwargs -:return: - - -* **Return type** - - tuple +### armada.operators.armada.log_exceptions(method) +## armada.triggers.armada module ## armada.auth module diff --git a/third_party/airflow/armada/auth.py b/third_party/airflow/armada/auth.py index ca90b521ecf..16275dbc343 100644 --- a/third_party/airflow/armada/auth.py +++ b/third_party/airflow/armada/auth.py @@ -1,5 +1,4 @@ -from typing import Dict, Any, Tuple, Protocol - +from typing import Any, Dict, Protocol, Tuple """ We use this interface for objects fetching Kubernetes auth tokens. Since it's used within the Trigger, it must be serialisable.""" diff --git a/third_party/airflow/armada/log_manager.py b/third_party/airflow/armada/log_manager.py new file mode 100644 index 00000000000..0a94ddeabbb --- /dev/null +++ b/third_party/airflow/armada/log_manager.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import math +import threading +from http.client import HTTPResponse +from typing import Dict, List, Optional, Tuple, cast + +import pendulum +from airflow.utils.log.logging_mixin import LoggingMixin +from armada.auth import TokenRetriever +from kubernetes import client, config +from pendulum import DateTime +from pendulum.parsing.exceptions import ParserError +from urllib3.exceptions import HTTPError + + +class KubernetesPodLogManager(LoggingMixin): + """Monitor logs of Kubernetes pods asynchronously.""" + + CLIENTS_LOCK = threading.Lock() + CLIENTS: Dict[str, client.CoreV1Api] = {} + + def __init__( + self, + token_retriever: Optional[TokenRetriever] = None, + ): + """ + Create PodLogManger. + :param token_retriever: Retrieves auth tokens + """ + super().__init__() + self._token_retriever = token_retriever + + def _k8s_client(self, k8s_context) -> client.CoreV1Api: + """ + K8S Clients are expensive to initialize (especially loading configuration). + We cache them per context in class level cache. + + Access to this method can be from multiple-threads. + """ + if k8s_context not in KubernetesPodLogManager.CLIENTS: + with KubernetesPodLogManager.CLIENTS_LOCK: + configuration = client.Configuration() + config.load_kube_config( + client_configuration=configuration, context=k8s_context + ) + k8s_client = client.CoreV1Api( + api_client=client.ApiClient(configuration=configuration) + ) + k8s_client.api_client.configuration.api_key_prefix["authorization"] = ( + "Bearer" + ) + KubernetesPodLogManager.CLIENTS[k8s_context] = k8s_client + return KubernetesPodLogManager.CLIENTS[k8s_context] + + def _with_bearer_auth(self, client): + client.api_client.configuration.api_key["authorization"] = ( + self._token_retriever.get_token() + ) + + def fetch_container_logs( + self, + *, + k8s_context: str, + namespace: str, + pod: str, + container: str, + since_time: Optional[DateTime], + ) -> Optional[DateTime]: + """ + Fetches container logs, do not follow container logs. + """ + client = self._k8s_client(k8s_context) + self._with_bearer_auth(client) + since_seconds = ( + math.ceil((pendulum.now() - since_time).total_seconds()) + if since_time + else None + ) + try: + logs = client.read_namespaced_pod_log( + namespace=namespace, + name=pod, + container=container, + follow=False, + timestamps=True, + since_seconds=since_seconds, + _preload_content=False, + ) + if logs.status == 404: + self.log.warning(f"Unable to fetch logs - pod {pod} has been deleted.") + return since_time + except HTTPError as e: + self.log.exception(f"There was an error reading the kubernetes API: {e}.") + raise + + return self._stream_logs(container, since_time, logs) + + def _stream_logs( + self, container: str, since_time: Optional[DateTime], logs: HTTPResponse + ) -> Optional[DateTime]: + messages: List[str] = [] + message_timestamp = None + try: + chunk = logs.read() + lines = chunk.decode("utf-8", errors="backslashreplace").splitlines() + for raw_line in lines: + line_timestamp, message = self._parse_log_line(raw_line) + + if line_timestamp: # detect new log-line (starts with timestamp) + if since_time and line_timestamp <= since_time: + continue + self._log_container_message(container, messages) + messages.clear() + message_timestamp = line_timestamp + messages.append(message) + except HTTPError as e: + self.log.warning( + f"Reading of logs interrupted for container {container} with error {e}." + ) + + self._log_container_message(container, messages) + return message_timestamp + + def _log_container_message(self, container: str, messages: List[str]): + if messages: + self.log.info("[%s] %s", container, "\n".join(messages)) + + def _parse_log_line(self, line: bytes) -> Tuple[DateTime | None, str]: + """ + Parse K8s log line and returns the final state. + + :param line: k8s log line + :return: timestamp and log message + """ + timestamp, sep, message = line.strip().partition(" ") + if not sep: + return None, line + try: + last_log_time = cast(DateTime, pendulum.parse(timestamp)) + except ParserError: + return None, line + return last_log_time, message diff --git a/third_party/airflow/armada/logs/log_consumer.py b/third_party/airflow/armada/logs/log_consumer.py deleted file mode 100644 index 8fd8c31d3ef..00000000000 --- a/third_party/airflow/armada/logs/log_consumer.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2016-2024 The Apache Software Foundation -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import queue -from datetime import timedelta -from http.client import HTTPResponse -from typing import Generator, TYPE_CHECKING, Callable, Awaitable, List - -from aiohttp.client_exceptions import ClientResponse -from airflow.utils.timezone import utcnow -from kubernetes.client import V1Pod -from kubernetes_asyncio.client import V1Pod as aio_V1Pod - -from armada.logs.utils import container_is_running, get_container_status - -if TYPE_CHECKING: - from urllib3.response import HTTPResponse # noqa: F811 - - -class PodLogsConsumerAsync: - """ - Responsible for pulling pod logs from a stream asynchronously, checking the - container status before reading data. - - This class contains a workaround for the issue - https://github.com/apache/airflow/issues/23497. - - :param response: HTTP response with logs - :param pod: Pod instance from Kubernetes client - :param read_pod_async: Callable returning a pod object that can be awaited on, - given (pod name, namespace) as arguments - :param container_name: Name of the container that we're reading logs from - :param post_termination_timeout: (Optional) The period of time in seconds - representing for how long time - logs are available after the container termination. - :param read_pod_cache_timeout: (Optional) The container's status cache lifetime. - The container status is cached to reduce API calls. - - :meta private: - """ - - def __init__( - self, - response: ClientResponse, - pod_name: str, - namespace: str, - read_pod_async: Callable[[str, str], Awaitable[aio_V1Pod]], - container_name: str, - post_termination_timeout: int = 120, - read_pod_cache_timeout: int = 120, - ): - self.response = response - self.pod_name = pod_name - self.namespace = namespace - self._read_pod_async = read_pod_async - self.container_name = container_name - self.post_termination_timeout = post_termination_timeout - self.last_read_pod_at = None - self.read_pod_cache = None - self.read_pod_cache_timeout = read_pod_cache_timeout - self.log_queue = queue.Queue() - - def __aiter__(self): - return self - - async def __anext__(self): - r"""Yield log items divided by the '\n' symbol.""" - if not self.log_queue.empty(): - return self.log_queue.get() - - incomplete_log_item: List[bytes] = [] - if await self.logs_available(): - async for data_chunk in self.response.content: - if b"\n" in data_chunk: - log_items = data_chunk.split(b"\n") - for x in self._extract_log_items(incomplete_log_item, log_items): - if x is not None: - self.log_queue.put(x) - incomplete_log_item = self._save_incomplete_log_item(log_items[-1]) - else: - incomplete_log_item.append(data_chunk) - if not await self.logs_available(): - break - else: - self.response.close() - raise StopAsyncIteration - if incomplete_log_item: - item = b"".join(incomplete_log_item) - if item is not None: - self.log_queue.put(item) - - # Prevents method from returning None - if not self.log_queue.empty(): - return self.log_queue.get() - - self.response.close() - raise StopAsyncIteration - - @staticmethod - def _extract_log_items(incomplete_log_item: List[bytes], log_items: List[bytes]): - yield b"".join(incomplete_log_item) + log_items[0] + b"\n" - for x in log_items[1:-1]: - yield x + b"\n" - - @staticmethod - def _save_incomplete_log_item(sub_chunk: bytes): - return [sub_chunk] if [sub_chunk] else [] - - async def logs_available(self): - remote_pod = await self.read_pod() - if container_is_running(pod=remote_pod, container_name=self.container_name): - return True - container_status = get_container_status( - pod=remote_pod, container_name=self.container_name - ) - state = container_status.state if container_status else None - terminated = state.terminated if state else None - if terminated: - termination_time = terminated.finished_at - if termination_time: - return ( - termination_time + timedelta(seconds=self.post_termination_timeout) - > utcnow() - ) - return False - - async def read_pod(self): - _now = utcnow() - if ( - self.read_pod_cache is None - or self.last_read_pod_at + timedelta(seconds=self.read_pod_cache_timeout) - < _now - ): - self.read_pod_cache = await self._read_pod_async( - self.pod_name, self.namespace - ) - self.last_read_pod_at = _now - return self.read_pod_cache - - -class PodLogsConsumer: - """ - Responsible for pulling pod logs from a stream with checking a container status - before reading data. - - This class is a workaround for the issue - https://github.com/apache/airflow/issues/23497. - - :param response: HTTP response with logs - :param pod: Pod instance from Kubernetes client - :param read_pod: Callable returning a pod object given (pod name, namespace) as - arguments - :param container_name: Name of the container that we're reading logs from - :param post_termination_timeout: (Optional) The period of time in seconds - representing for how long time - logs are available after the container termination. - :param read_pod_cache_timeout: (Optional) The container's status cache lifetime. - The container status is cached to reduce API calls. - - :meta private: - """ - - def __init__( - self, - response: HTTPResponse, - pod_name: str, - namespace: str, - read_pod: Callable[[str, str], V1Pod], - container_name: str, - post_termination_timeout: int = 120, - read_pod_cache_timeout: int = 120, - ): - self.response = response - self.pod_name = pod_name - self.namespace = namespace - self._read_pod = read_pod - self.container_name = container_name - self.post_termination_timeout = post_termination_timeout - self.last_read_pod_at = None - self.read_pod_cache = None - self.read_pod_cache_timeout = read_pod_cache_timeout - - def __iter__(self) -> Generator[bytes, None, None]: - r"""Yield log items divided by the '\n' symbol.""" - incomplete_log_item: List[bytes] = [] - if self.logs_available(): - for data_chunk in self.response.stream(amt=None, decode_content=True): - if b"\n" in data_chunk: - log_items = data_chunk.split(b"\n") - yield from self._extract_log_items(incomplete_log_item, log_items) - incomplete_log_item = self._save_incomplete_log_item(log_items[-1]) - else: - incomplete_log_item.append(data_chunk) - if not self.logs_available(): - break - if incomplete_log_item: - yield b"".join(incomplete_log_item) - - @staticmethod - def _extract_log_items(incomplete_log_item: List[bytes], log_items: List[bytes]): - yield b"".join(incomplete_log_item) + log_items[0] + b"\n" - for x in log_items[1:-1]: - yield x + b"\n" - - @staticmethod - def _save_incomplete_log_item(sub_chunk: bytes): - return [sub_chunk] if [sub_chunk] else [] - - def logs_available(self): - remote_pod = self.read_pod() - if container_is_running(pod=remote_pod, container_name=self.container_name): - return True - container_status = get_container_status( - pod=remote_pod, container_name=self.container_name - ) - state = container_status.state if container_status else None - terminated = state.terminated if state else None - if terminated: - termination_time = terminated.finished_at - if termination_time: - return ( - termination_time + timedelta(seconds=self.post_termination_timeout) - > utcnow() - ) - return False - - def read_pod(self): - _now = utcnow() - if ( - self.read_pod_cache is None - or self.last_read_pod_at + timedelta(seconds=self.read_pod_cache_timeout) - < _now - ): - self.read_pod_cache = self._read_pod(self.pod_name, self.namespace) - self.last_read_pod_at = _now - return self.read_pod_cache diff --git a/third_party/airflow/armada/logs/pod_log_manager.py b/third_party/airflow/armada/logs/pod_log_manager.py deleted file mode 100644 index 20e8e51c852..00000000000 --- a/third_party/airflow/armada/logs/pod_log_manager.py +++ /dev/null @@ -1,550 +0,0 @@ -# Copyright 2016-2024 The Apache Software Foundation -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import asyncio -import math -import time -from dataclasses import dataclass -from functools import cached_property -from typing import TYPE_CHECKING, cast, Optional - -import pendulum -import tenacity -from kubernetes import client, watch, config -from kubernetes_asyncio import client as async_client, config as async_config -from kubernetes.client.rest import ApiException -from pendulum import DateTime -from pendulum.parsing.exceptions import ParserError -from urllib3.exceptions import HTTPError as BaseHTTPError - -from airflow.exceptions import AirflowException -from airflow.utils.log.logging_mixin import LoggingMixin - -from armada.auth import TokenRetriever -from armada.logs.log_consumer import PodLogsConsumer, PodLogsConsumerAsync -from armada.logs.utils import container_is_running - -if TYPE_CHECKING: - from kubernetes.client.models.v1_pod import V1Pod - - -class PodPhase: - """ - Possible pod phases. - - See https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase. - """ - - PENDING = "Pending" - RUNNING = "Running" - FAILED = "Failed" - SUCCEEDED = "Succeeded" - - terminal_states = {FAILED, SUCCEEDED} - - -@dataclass -class PodLoggingStatus: - """Return the status of the pod and last log time when exiting from - `fetch_container_logs`.""" - - running: bool - last_log_time: DateTime | None - - -class PodLogManagerAsync(LoggingMixin): - """Monitor logs of Kubernetes pods asynchronously.""" - - def __init__( - self, - k8s_context: str, - token_retriever: Optional[TokenRetriever] = None, - ): - """ - Create the launcher. - - :param k8s_context: kubernetes context - :param token_retriever: Retrieves auth tokens - """ - super().__init__() - self._k8s_context = k8s_context - self._watch = watch.Watch() - self._k8s_client = None - self._token_retriever = token_retriever - - async def _refresh_k8s_auth_token(self, interval=60 * 5): - if self._token_retriever is not None: - while True: - await asyncio.sleep(interval) - self._k8s_client.api_client.configuration.api_key["BearerToken"] = ( - f"Bearer {self._token_retriever.get_token()}" - ) - - async def k8s_client(self) -> async_client: - await async_config.load_kube_config(context=self._k8s_context) - asyncio.create_task(self._refresh_k8s_auth_token()) - return async_client.CoreV1Api() - - async def fetch_container_logs( - self, - pod_name: str, - namespace: str, - container_name: str, - *, - follow=False, - since_time: DateTime | None = None, - post_termination_timeout: int = 120, - ) -> PodLoggingStatus: - """ - Follow the logs of container and stream to airflow logging. Doesn't block whilst - logs are being fetched. - - Returns when container exits. - - Between when the pod starts and logs being available, there might be a delay due - to CSR not approved - and signed yet. In such situation, ApiException is thrown. This is why we are - retrying on this - specific exception. - """ - # Can't await in constructor, so instantiating here - if self._k8s_client is None: - self._k8s_client = await self.k8s_client() - - @tenacity.retry( - retry=tenacity.retry_if_exception_type(ApiException), - stop=tenacity.stop_after_attempt(10), - wait=tenacity.wait_fixed(1), - ) - async def consume_logs( - *, - since_time: DateTime | None = None, - follow: bool = True, - logs: PodLogsConsumerAsync | None, - ) -> tuple[DateTime | None, PodLogsConsumerAsync | None]: - """ - Try to follow container logs until container completes. - - For a long-running container, sometimes the log read may be interrupted - Such errors of this kind are suppressed. - - Returns the last timestamp observed in logs. - """ - last_captured_timestamp = None - try: - logs = await self._read_pod_logs( - pod_name=pod_name, - namespace=namespace, - container_name=container_name, - timestamps=True, - since_seconds=( - math.ceil((pendulum.now() - since_time).total_seconds()) - if since_time - else None - ), - follow=follow, - post_termination_timeout=post_termination_timeout, - ) - message_to_log = None - message_timestamp = None - progress_callback_lines = [] - try: - async for raw_line in logs: - line = raw_line.decode("utf-8", errors="backslashreplace") - line_timestamp, message = self._parse_log_line(line) - if line_timestamp: # detect new log line - if message_to_log is None: # first line in the log - message_to_log = message - message_timestamp = line_timestamp - progress_callback_lines.append(line) - else: # previous log line is complete - self.log.info("[%s] %s", container_name, message_to_log) - last_captured_timestamp = message_timestamp - message_to_log = message - message_timestamp = line_timestamp - progress_callback_lines = [line] - else: # continuation of the previous log line - message_to_log = f"{message_to_log}\n{message}" - progress_callback_lines.append(line) - finally: - if message_to_log is not None: - self.log.info("[%s] %s", container_name, message_to_log) - last_captured_timestamp = message_timestamp - except BaseHTTPError as e: - self.log.warning( - "Reading of logs interrupted for container %r with error %r; will " - "retry. " - "Set log level to DEBUG for traceback.", - container_name, - e, - ) - self.log.debug( - "Traceback for interrupted logs read for pod %r", - pod_name, - exc_info=True, - ) - return last_captured_timestamp or since_time, logs - - # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to - # loop as we do here. But in a long-running process we might temporarily lose - # connectivity. - # So the looping logic is there to let us resume following the logs. - logs = None - last_log_time = since_time - while True: - last_log_time, logs = await consume_logs( - since_time=last_log_time, - follow=follow, - logs=logs, - ) - if not await self._container_is_running_async( - pod_name, namespace, container_name=container_name - ): - return PodLoggingStatus(running=False, last_log_time=last_log_time) - if not follow: - return PodLoggingStatus(running=True, last_log_time=last_log_time) - else: - self.log.warning( - "Pod %s log read interrupted but container %s still running", - pod_name, - container_name, - ) - time.sleep(1) - - def _parse_log_line(self, line: str) -> tuple[DateTime | None, str]: - """ - Parse K8s log line and returns the final state. - - :param line: k8s log line - :return: timestamp and log message - """ - timestamp, sep, message = line.strip().partition(" ") - if not sep: - return None, line - try: - last_log_time = cast(DateTime, pendulum.parse(timestamp)) - except ParserError: - return None, line - return last_log_time, message - - async def _container_is_running_async( - self, pod_name: str, namespace: str, container_name: str - ) -> bool: - """Read pod and checks if container is running.""" - remote_pod = await self.read_pod(pod_name, namespace) - return container_is_running(pod=remote_pod, container_name=container_name) - - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True, - ) - async def _read_pod_logs( - self, - pod_name: str, - namespace: str, - container_name: str, - tail_lines: int | None = None, - timestamps: bool = False, - since_seconds: int | None = None, - follow=True, - post_termination_timeout: int = 120, - ) -> PodLogsConsumerAsync: - """Read log from the POD.""" - additional_kwargs = {} - if since_seconds: - additional_kwargs["since_seconds"] = since_seconds - - if tail_lines: - additional_kwargs["tail_lines"] = tail_lines - - try: - logs = await self._k8s_client.read_namespaced_pod_log( - name=pod_name, - namespace=namespace, - container=container_name, - follow=follow, - timestamps=timestamps, - _preload_content=False, - **additional_kwargs, - ) - except BaseHTTPError: - self.log.exception("There was an error reading the kubernetes API.") - raise - - return PodLogsConsumerAsync( - response=logs, - pod_name=pod_name, - namespace=namespace, - read_pod_async=self.read_pod, - container_name=container_name, - post_termination_timeout=post_termination_timeout, - ) - - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True, - ) - async def read_pod(self, pod_name: str, namespace: str) -> V1Pod: - """Read POD information.""" - try: - return await self._k8s_client.read_namespaced_pod(pod_name, namespace) - except BaseHTTPError as e: - raise AirflowException( - f"There was an error reading the kubernetes API: {e}" - ) - - -class PodLogManager(LoggingMixin): - """Monitor logs of Kubernetes pods.""" - - def __init__( - self, k8s_context: str, token_retriever: Optional[TokenRetriever] = None - ): - """ - Create the launcher. - - :param k8s_context: kubernetes context - :param token_retriever: Retrieves auth tokens - """ - super().__init__() - self._k8s_context = k8s_context - self._watch = watch.Watch() - self._token_retriever = token_retriever - - def _refresh_k8s_auth_token(self): - if self._token_retriever is not None: - self._k8s_client.api_client.configuration.api_key["BearerToken"] = ( - f"Bearer {self._token_retriever.get_token()}" - ) - - @cached_property - def _k8s_client(self) -> client: - config.load_kube_config(context=self._k8s_context) - return client.CoreV1Api() - - def fetch_container_logs( - self, - pod_name: str, - namespace: str, - container_name: str, - *, - follow=False, - since_time: DateTime | None = None, - post_termination_timeout: int = 120, - ) -> PodLoggingStatus: - """ - Follow the logs of container and stream to airflow logging. - - Returns when container exits. - - Between when the pod starts and logs being available, there might be a delay due - to CSR not approved - and signed yet. In such situation, ApiException is thrown. This is why we are - retrying on this - specific exception. - """ - - @tenacity.retry( - retry=tenacity.retry_if_exception_type(ApiException), - stop=tenacity.stop_after_attempt(10), - wait=tenacity.wait_fixed(1), - ) - def consume_logs( - *, - since_time: DateTime | None = None, - follow: bool = True, - logs: PodLogsConsumer | None, - ) -> tuple[DateTime | None, PodLogsConsumer | None]: - """ - Try to follow container logs until container completes. - - For a long-running container, sometimes the log read may be interrupted - Such errors of this kind are suppressed. - - Returns the last timestamp observed in logs. - """ - last_captured_timestamp = None - try: - logs = self._read_pod_logs( - pod_name=pod_name, - namespace=namespace, - container_name=container_name, - timestamps=True, - since_seconds=( - math.ceil((pendulum.now() - since_time).total_seconds()) - if since_time - else None - ), - follow=follow, - post_termination_timeout=post_termination_timeout, - ) - message_to_log = None - message_timestamp = None - progress_callback_lines = [] - try: - for raw_line in logs: - line = raw_line.decode("utf-8", errors="backslashreplace") - line_timestamp, message = self._parse_log_line(line) - if line_timestamp: # detect new log line - if message_to_log is None: # first line in the log - message_to_log = message - message_timestamp = line_timestamp - progress_callback_lines.append(line) - else: # previous log line is complete - self.log.info("[%s] %s", container_name, message_to_log) - last_captured_timestamp = message_timestamp - message_to_log = message - message_timestamp = line_timestamp - progress_callback_lines = [line] - else: # continuation of the previous log line - message_to_log = f"{message_to_log}\n{message}" - progress_callback_lines.append(line) - finally: - if message_to_log is not None: - self.log.info("[%s] %s", container_name, message_to_log) - last_captured_timestamp = message_timestamp - except BaseHTTPError as e: - self.log.warning( - "Reading of logs interrupted for container %r with error %r; will " - "retry. " - "Set log level to DEBUG for traceback.", - container_name, - e, - ) - self.log.debug( - "Traceback for interrupted logs read for pod %r", - pod_name, - exc_info=True, - ) - return last_captured_timestamp or since_time, logs - - # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to - # loop as we do here. But in a long-running process we might temporarily lose - # connectivity. - # So the looping logic is there to let us resume following the logs. - logs = None - last_log_time = since_time - while True: - last_log_time, logs = consume_logs( - since_time=last_log_time, - follow=follow, - logs=logs, - ) - if not self._container_is_running( - pod_name, namespace, container_name=container_name - ): - return PodLoggingStatus(running=False, last_log_time=last_log_time) - if not follow: - return PodLoggingStatus(running=True, last_log_time=last_log_time) - else: - self.log.warning( - "Pod %s log read interrupted but container %s still running", - pod_name, - container_name, - ) - time.sleep(1) - self._refresh_k8s_auth_token() - - def _parse_log_line(self, line: str) -> tuple[DateTime | None, str]: - """ - Parse K8s log line and returns the final state. - - :param line: k8s log line - :return: timestamp and log message - """ - timestamp, sep, message = line.strip().partition(" ") - if not sep: - return None, line - try: - last_log_time = cast(DateTime, pendulum.parse(timestamp)) - except ParserError: - return None, line - return last_log_time, message - - def _container_is_running( - self, pod_name: str, namespace: str, container_name: str - ) -> bool: - """Read pod and checks if container is running.""" - remote_pod = self.read_pod(pod_name, namespace) - return container_is_running(pod=remote_pod, container_name=container_name) - - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True, - ) - def _read_pod_logs( - self, - pod_name: str, - namespace: str, - container_name: str, - tail_lines: int | None = None, - timestamps: bool = False, - since_seconds: int | None = None, - follow=True, - post_termination_timeout: int = 120, - ) -> PodLogsConsumer: - """Read log from the POD.""" - additional_kwargs = {} - if since_seconds: - additional_kwargs["since_seconds"] = since_seconds - - if tail_lines: - additional_kwargs["tail_lines"] = tail_lines - - try: - logs = self._k8s_client.read_namespaced_pod_log( - name=pod_name, - namespace=namespace, - container=container_name, - follow=follow, - timestamps=timestamps, - _preload_content=False, - **additional_kwargs, - ) - except BaseHTTPError: - self.log.exception("There was an error reading the kubernetes API.") - raise - - return PodLogsConsumer( - response=logs, - pod_name=pod_name, - namespace=namespace, - read_pod=self.read_pod, - container_name=container_name, - post_termination_timeout=post_termination_timeout, - ) - - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(), - reraise=True, - ) - def read_pod(self, pod_name: str, namespace: str) -> V1Pod: - """Read POD information.""" - try: - return self._k8s_client.read_namespaced_pod(pod_name, namespace) - except BaseHTTPError as e: - raise AirflowException( - f"There was an error reading the kubernetes API: {e}" - ) diff --git a/third_party/airflow/armada/logs/utils.py b/third_party/airflow/armada/logs/utils.py deleted file mode 100644 index ade71ba5fbe..00000000000 --- a/third_party/airflow/armada/logs/utils.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016-2024 The Apache Software Foundation -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import TYPE_CHECKING - -from kubernetes.client import V1Pod, V1ContainerStatus - -if TYPE_CHECKING: - from kubernetes.client.models.v1_container_status import ( # noqa: F811 - V1ContainerStatus, - ) - from kubernetes.client.models.v1_pod import V1Pod # noqa: F811 - - -def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus: - """Retrieve container status.""" - container_statuses = pod.status.container_statuses if pod and pod.status else None - if container_statuses: - # In general the variable container_statuses can store multiple items matching - # different containers. - # The following generator expression yields all items that have name equal to - # the container_name. - # The function next() here calls the generator to get only the first value. If - # there's nothing found - # then None is returned. - return next((x for x in container_statuses if x.name == container_name), None) - return None - - -def container_is_running(pod: V1Pod, container_name: str) -> bool: - """ - Examine V1Pod ``pod`` to determine whether ``container_name`` is running. - - If that container is present and running, returns True. Returns False otherwise. - """ - container_status = get_container_status(pod, container_name) - if not container_status: - return False - return container_status.state.running is not None diff --git a/third_party/airflow/armada/model.py b/third_party/airflow/armada/model.py index 80e6e0d0a77..00b9ab59800 100644 --- a/third_party/airflow/armada/model.py +++ b/third_party/airflow/armada/model.py @@ -1,9 +1,8 @@ import importlib -from typing import Tuple, Any, Optional, Sequence, Dict +from typing import Any, Dict, Optional, Sequence, Tuple import grpc - """ This class exists so that we can retain our connection to the Armada Query API when using the deferrable Armada Airflow Operator. Airflow requires any state within deferrable operators be serialisable, unfortunately grpc.Channel isn't diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index cb9fd361c27..7e365417ed3 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -15,29 +15,127 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations +import asyncio +import datetime +import functools import os +import threading import time -from functools import lru_cache, cached_property -from typing import Optional, Sequence, Any, Dict +from dataclasses import dataclass +from functools import cached_property +from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple import jinja2 from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator - +from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.context import Context from airflow.utils.log.logging_mixin import LoggingMixin +from armada.auth import TokenRetriever +from armada.log_manager import KubernetesPodLogManager +from armada.model import GrpcChannelArgs from armada_client.armada.job_pb2 import JobRunDetails -from armada_client.typings import JobState from armada_client.armada.submit_pb2 import JobSubmitRequestItem +from armada_client.client import ArmadaClient +from armada_client.typings import JobState from google.protobuf.json_format import MessageToDict, ParseDict +from pendulum import DateTime -from armada_client.client import ArmadaClient -from armada.auth import TokenRetriever -from armada.logs.pod_log_manager import PodLogManager -from armada.model import GrpcChannelArgs -from armada.triggers.armada import ArmadaTrigger + +def log_exceptions(method): + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + try: + return method(self, *args, **kwargs) + except Exception as e: + if hasattr(self, "log") and hasattr(self.log, "error"): + self.log.error(f"Exception in {method.__name__}: {e}") + raise + + return wrapper + + +@dataclass(frozen=False) +class _RunningJobContext: + armada_queue: str + job_set_id: str + job_id: str + state: JobState = JobState.UNKNOWN + start_time: DateTime = DateTime.utcnow() + cluster: Optional[str] = None + last_log_time: Optional[DateTime] = None + + def serialize(self) -> tuple[str, Dict[str, Any]]: + return ( + "armada.operators.armada._RunningJobContext", + { + "armada_queue": self.armada_queue, + "job_set_id": self.job_set_id, + "job_id": self.job_id, + "state": self.state.value, + "start_time": self.start_time, + "cluster": self.cluster, + "last_log_time": self.last_log_time, + }, + ) + + def from_payload(payload: Dict[str, Any]) -> _RunningJobContext: + return _RunningJobContext( + armada_queue=payload["armada_queue"], + job_set_id=payload["job_set_id"], + job_id=payload["job_id"], + state=JobState(payload["state"]), + start_time=payload["start_time"], + cluster=payload["cluster"], + last_log_time=payload["last_log_time"], + ) + + +class _ArmadaPollJobTrigger(BaseTrigger): + def __init__(self, moment: datetime.timedelta, context: _RunningJobContext) -> None: + super().__init__() + self.moment = moment + self.context = context + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "armada.operators.armada._ArmadaPollJobTrigger", + {"moment": self.moment, "context": self.context.serialize()}, + ) + + def __eq__(self, value: object) -> bool: + if not isinstance(value, _ArmadaPollJobTrigger): + return False + return self.moment == value.moment and self.context == value.context + + async def run(self) -> AsyncIterator[TriggerEvent]: + while self.moment > DateTime.utcnow(): + await asyncio.sleep(1) + yield TriggerEvent(self.context) + + +class _ArmadaClientFactory: + CLIENTS_LOCK = threading.Lock() + CLIENTS: Dict[str, ArmadaClient] = {} + + @staticmethod + def client_for(args: GrpcChannelArgs) -> ArmadaClient: + """ + Armada clients, maintain GRPC connection to Armada API. + We cache them per channel args config in class level cache. + + Access to this method can be from multiple-threads. + """ + channel_args_key = str(args.serialize()) + with _ArmadaClientFactory.CLIENTS_LOCK: + if channel_args_key not in _ArmadaClientFactory.CLIENTS: + _ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient( + channel=args.channel() + ) + return _ArmadaClientFactory.CLIENTS[channel_args_key] class ArmadaOperator(BaseOperator, LoggingMixin): @@ -61,7 +159,7 @@ class ArmadaOperator(BaseOperator, LoggingMixin): :type armada_queue: str :param job_request: The job to be submitted to Armada. :type job_request: JobSubmitRequestItem -:param job_set_prefix: A string to prepend to the jobSet name +:param job_set_prefix: A string to prepend to the jobSet name. :type job_set_prefix: Optional[str] :param lookout_url_template: Template for creating lookout links. If not specified then no tracking information will be logged. @@ -94,7 +192,7 @@ def __init__( container_logs: Optional[str] = None, k8s_token_retriever: Optional[TokenRetriever] = None, deferrable: bool = conf.getboolean( - "operators", "default_deferrable", fallback=False + "operators", "default_deferrable", fallback=True ), job_acknowledgement_timeout: int = 5 * 60, **kwargs, @@ -104,6 +202,7 @@ def __init__( self.channel_args = channel_args self.armada_queue = armada_queue self.job_request = job_request + self.job_set_id = None self.job_set_prefix = job_set_prefix self.lookout_url_template = lookout_url_template self.poll_interval = poll_interval @@ -111,8 +210,7 @@ def __init__( self.k8s_token_retriever = k8s_token_retriever self.deferrable = deferrable self.job_acknowledgement_timeout = job_acknowledgement_timeout - self.job_id = None - self.job_set_id = None + self.job_context = None if self.container_logs and self.k8s_token_retriever is None: self.log.warning( @@ -120,6 +218,7 @@ def __init__( "logs from Kubernetes" ) + @log_exceptions def execute(self, context) -> None: """ Submits the job to Armada and polls for completion. @@ -130,45 +229,33 @@ def execute(self, context) -> None: # We take the job_set_id from Airflow's run_id. This means that all jobs in the # dag will be in the same jobset. self.job_set_id = f"{self.job_set_prefix}{context['run_id']}" + self._annotate_job_request(context, self.job_request) # Submit job or reattach to previously submitted job. We always do this # synchronously. - self.job_id = self._reattach_or_submit_job( + job_id = self._reattach_or_submit_job( context, self.armada_queue, self.job_set_id, self.job_request ) # Wait until finished + self.job_context = _RunningJobContext( + self.armada_queue, self.job_set_id, job_id, start_time=DateTime.utcnow() + ) if self.deferrable: - self.defer( - timeout=self.execution_timeout, - trigger=ArmadaTrigger( - job_id=self.job_id, - armada_queue=self.armada_queue, - job_set_id=self.job_set_id, - channel_args=self.channel_args, - poll_interval=self.poll_interval, - tracking_message=self._trigger_tracking_message(), - job_acknowledgement_timeout=self.job_acknowledgement_timeout, - container_logs=self.container_logs, - k8s_token_retriever=self.k8s_token_retriever, - job_request_namespace=self.job_request.namespace, - ), - method_name="_execute_complete", - ) + self._deffered_yield(self.job_context) else: - self._poll_for_termination(self._trigger_tracking_message()) + self._poll_for_termination(self.job_context) @cached_property def client(self) -> ArmadaClient: - return ArmadaClient(channel=self.channel_args.channel()) + return _ArmadaClientFactory.client_for(self.channel_args) - @lru_cache(maxsize=None) - def pod_manager(self, k8s_context: str) -> PodLogManager: - return PodLogManager( - k8s_context=k8s_context, token_retriever=self.k8s_token_retriever - ) + @cached_property + def pod_manager(self) -> KubernetesPodLogManager: + return KubernetesPodLogManager(token_retriever=self.k8s_token_retriever) + @log_exceptions def render_template_fields( self, context: Context, @@ -189,40 +276,59 @@ def render_template_fields( super().render_template_fields(context, jinja_env) self.job_request = ParseDict(self.job_request, JobSubmitRequestItem()) - def _cancel_job(self) -> None: + def _cancel_job(self, job_context) -> None: try: result = self.client.cancel_jobs( - queue=self.armada_queue, - job_set_id=self.job_set_id, - job_id=self.job_id, + queue=job_context.armada_queue, + job_set_id=job_context.job_set_id, + job_id=job_context.job_id, ) if len(list(result.cancelled_ids)) > 0: self.log.info(f"Cancelled job with id {result.cancelled_ids}") else: - self.log.warning(f"Failed to cancel job with id {self.job_id}") + self.log.warning(f"Failed to cancel job with id {job_context.job_id}") except Exception as e: - self.log.warning(f"Failed to cancel job with id {self.job_id}: {e}") + self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}") def on_kill(self) -> None: - if self.job_id is not None: + if self.job_context is not None: self.log.info( - f"on_kill called, cancelling job with id {self.job_id} in queue " - f"{self.armada_queue}" + f"on_kill called, " + "cancelling job with id {self.job_context.job_id} in queue " + f"{self.job_context.armada_queue}" ) - self._cancel_job() + self._cancel_job(self.job_context) - def _trigger_tracking_message(self): + def _trigger_tracking_message(self, job_id: str): if self.lookout_url_template: return ( f"Job details available at " - f'{self.lookout_url_template.replace("", self.job_id)}' + f'{self.lookout_url_template.replace("", job_id)}' ) return "" - def _execute_complete(self, _: Context, event: Dict[str, Any]): - if event["status"] == "error": - raise AirflowException(event["response"]) + def _deffered_yield(self, context: _RunningJobContext): + self.defer( + timeout=self.execution_timeout, + trigger=_ArmadaPollJobTrigger( + DateTime.utcnow() + datetime.timedelta(seconds=self.poll_interval), + context, + ), + method_name="_deffered_poll_for_termination", + ) + + @log_exceptions + def _deffered_poll_for_termination( + self, context: Context, event: Tuple[str, Dict[str, Any]] + ) -> None: + job_run_context = _RunningJobContext.from_payload(event[1]) + while job_run_context.state.is_active(): + job_run_context = self._check_job_status_and_fetch_logs(job_run_context) + if job_run_context.state.is_active(): + self._deffered_yield(job_run_context) + + self._running_job_terminated(job_run_context) def _reattach_or_submit_job( self, @@ -237,12 +343,15 @@ def _reattach_or_submit_job( ) if existing_id is not None: self.log.info( - f"Attached to existing job with id {existing_id['armada_job_id']}" + f"Attached to existing job with id {existing_id['armada_job_id']}." + f" {self._trigger_tracking_message(existing_id['armada_job_id'])}" ) return existing_id["armada_job_id"] job_id = self._submit_job(queue, job_set_id, job_request) - self.log.info(f"Submitted job with id {job_id}") + self.log.info( + f"Submitted job with id {job_id}. {self._trigger_tracking_message(job_id)}" + ) ti.xcom_push(key=f"{ti.try_number}", value={"armada_job_id": job_id}) return job_id @@ -266,61 +375,67 @@ def _submit_job( return job.job_id - def _poll_for_termination(self, tracking_message: str) -> None: - last_log_time = None - run_details = None - state = JobState.UNKNOWN + def _poll_for_termination(self, context: _RunningJobContext) -> None: + while context.state.is_active(): + context = self._check_job_status_and_fetch_logs(context) + if context.state.is_active(): + time.sleep(self.poll_interval) - start_time = time.time() - job_acknowledged = False - while state.is_active(): - response = self.client.get_job_status([self.job_id]) - state = JobState(response.job_states[self.job_id]) - self.log.info( - f"job {self.job_id} is in state: {state.name}. {tracking_message}" + self._running_job_terminated(context) + + def _running_job_terminated(self, context: _RunningJobContext): + self.log.info( + f"job {context.job_id} terminated with state: {context.state.name}" + ) + if context.state != JobState.SUCCEEDED: + raise AirflowException( + f"job {context.job_id} did not succeed. " + f"Final status was {context.state.name}" ) - if state != JobState.UNKNOWN: - job_acknowledged = True + @log_exceptions + def _check_job_status_and_fetch_logs( + self, context: _RunningJobContext + ) -> _RunningJobContext: + response = self.client.get_job_status([context.job_id]) + state = JobState(response.job_states[context.job_id]) + if state != context.state: + self.log.info( + f"job {context.job_id} is in state: {state.name}. " + f"{self._trigger_tracking_message(context.job_id)}" + ) + context.state = state + if context.state == JobState.UNKNOWN: if ( - not job_acknowledged - and int(time.time() - start_time) > self.job_acknowledgement_timeout + DateTime.utcnow().diff(context.start_time).in_seconds() + > self.job_acknowledgement_timeout ): self.log.info( - f"Job {self.job_id} not acknowledged by the Armada server within " + f"Job {context.job_id} not acknowledged by the Armada within " f"timeout ({self.job_acknowledgement_timeout}), terminating" ) - self.on_kill() - return - - if self.container_logs and not run_details: - if state == JobState.RUNNING or state.is_terminal(): - run_details = self._get_latest_job_run_details(self.job_id) - - if run_details: - try: - # pod_name format is sufficient for now. Ideally pod name should be - # retrieved from queryapi - log_status = self.pod_manager( - run_details.cluster - ).fetch_container_logs( - pod_name=f"armada-{self.job_id}-0", - namespace=self.job_request.namespace, - container_name=self.container_logs, - since_time=last_log_time, - ) - last_log_time = log_status.last_log_time - except Exception as e: - self.log.warning(f"Error fetching logs {e}") - - time.sleep(self.poll_interval) - - self.log.info(f"job {self.job_id} terminated with state: {state.name}") - if state != JobState.SUCCEEDED: - raise AirflowException( - f"job {self.job_id} did not succeed. Final status was {state.name}" - ) + self._cancel_job(context) + context.state = JobState.CANCELLED + return context + + if self.container_logs and not context.cluster: + if context.state == JobState.RUNNING or context.state.is_terminal(): + run_details = self._get_latest_job_run_details(context.job_id) + context.cluster = run_details.cluster + + if context.cluster: + try: + context.last_log_time = self.pod_manager.fetch_container_logs( + k8s_context=context.cluster, + namespace=self.job_request.namespace, + pod=f"armada-{context.job_id}-0", + container=self.container_logs, + since_time=context.last_log_time, + ) + except Exception as e: + self.log.warning(f"Error fetching logs {e}") + return context def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: job_details = self.client.get_job_details([job_id]).job_details[job_id] diff --git a/third_party/airflow/armada/triggers/armada.py b/third_party/airflow/armada/triggers/armada.py deleted file mode 100644 index 284fe305169..00000000000 --- a/third_party/airflow/armada/triggers/armada.py +++ /dev/null @@ -1,269 +0,0 @@ -import asyncio -import importlib -import time -from functools import cached_property -from typing import AsyncIterator, Any, Optional, Tuple, Dict - -from airflow.triggers.base import BaseTrigger, TriggerEvent -from armada_client.armada.job_pb2 import JobRunDetails -from armada_client.typings import JobState - -from armada_client.asyncio_client import ArmadaAsyncIOClient -from armada.auth import TokenRetriever -from armada.logs.pod_log_manager import PodLogManagerAsync -from armada.model import GrpcChannelArgs -from pendulum import DateTime - - -class ArmadaTrigger(BaseTrigger): - """ - An Airflow Trigger that can asynchronously manage an Armada job. - """ - - def __init__( - self, - job_id: str, - armada_queue: str, - job_set_id: str, - poll_interval: int, - tracking_message: str, - job_acknowledgement_timeout: int, - job_request_namespace: str, - channel_args: GrpcChannelArgs = None, - channel_args_details: Dict[str, Any] = None, - container_logs: Optional[str] = None, - k8s_token_retriever: Optional[TokenRetriever] = None, - k8s_token_retriever_details: Optional[Tuple[str, Dict[str, Any]]] = None, - last_log_time: Optional[DateTime] = None, - ): - """ - Initializes an instance of ArmadaTrigger, which is an Airflow trigger for - managing Armada jobs asynchronously. - - :param job_id: The unique identifier of the job to be monitored. - :type job_id: str - :param armada_queue: The Armada queue under which the job was submitted. - Required for job cancellation. - :type armada_queue: str - :param job_set_id: The unique identifier of the job set under which the job - was submitted. Required for job cancellation. - :type job_set_id: str - :param poll_interval: The interval, in seconds, at which the job status will be - checked. - :type poll_interval: int - :param tracking_message: A message to log or display for tracking the job - status. - :type tracking_message: str - :param job_acknowledgement_timeout: The timeout, in seconds, to wait for the job - to be acknowledged by Armada. - :type job_acknowledgement_timeout: int - :param job_request_namespace: The Kubernetes namespace under which the job was - submitted. - :type job_request_namespace: str - :param channel_args: The arguments to configure the gRPC channel. If None, - default arguments will be used. - :type channel_args: GrpcChannelArgs, optional - :param channel_args_details: Additional details or configurations for the gRPC - channel as a dictionary. Only used when - the trigger is rehydrated after serialization. - :type channel_args_details: dict[str, Any], optional - :param container_logs: Name of container from which to retrieve logs - :type container_logs: str, optional - :param k8s_token_retriever: An optional instance of type TokenRetriever, used to - refresh the Kubernetes auth token - :type k8s_token_retriever: TokenRetriever, optional - :param k8s_token_retriever_details: Configuration for TokenRetriever as a - dictionary. - Only used when the trigger is - rehydrated after serialization. - :type k8s_token_retriever_details: Tuple[str, Dict[str, Any]], optional - :param last_log_time: where to resume logs from - :type last_log_time: DateTime, optional - """ - super().__init__() - self.job_id = job_id - self.armada_queue = armada_queue - self.job_set_id = job_set_id - self.poll_interval = poll_interval - self.tracking_message = tracking_message - self.job_acknowledgement_timeout = job_acknowledgement_timeout - self.container_logs = container_logs - self.last_log_time = last_log_time - self.job_request_namespace = job_request_namespace - self._pod_manager = None - self.k8s_token_retriever = k8s_token_retriever - - if channel_args: - self.channel_args = channel_args - elif channel_args_details: - self.channel_args = GrpcChannelArgs(**channel_args_details) - else: - raise f"must provide either {channel_args} or {channel_args_details}" - - if k8s_token_retriever_details: - classpath, kwargs = k8s_token_retriever_details - module_path, class_name = classpath.rsplit( - ".", 1 - ) # Split the classpath to module and class name - module = importlib.import_module( - module_path - ) # Dynamically import the module - cls = getattr(module, class_name) # Get the class from the module - self.k8s_token_retriever = cls( - **kwargs - ) # Instantiate the class with the deserialized kwargs - - def serialize(self) -> tuple: - """ - Serialises the state of this Trigger. - When the Trigger is re-hydrated, these values will be passed to init() as kwargs - :return: - """ - k8s_token_retriever_details = ( - self.k8s_token_retriever.serialize() if self.k8s_token_retriever else None - ) - return ( - "armada.triggers.armada.ArmadaTrigger", - { - "job_id": self.job_id, - "armada_queue": self.armada_queue, - "job_set_id": self.job_set_id, - "channel_args_details": self.channel_args.serialize(), - "poll_interval": self.poll_interval, - "tracking_message": self.tracking_message, - "job_acknowledgement_timeout": self.job_acknowledgement_timeout, - "container_logs": self.container_logs, - "k8s_token_retriever_details": k8s_token_retriever_details, - "last_log_time": self.last_log_time, - "job_request_namespace": self.job_request_namespace, - }, - ) - - async def run(self) -> AsyncIterator[TriggerEvent]: - """ - Run the Trigger Asynchronously. This will poll Armada until the Job reaches a - terminal state - """ - try: - response = await self._poll_for_termination(self.job_id) - yield TriggerEvent(response) - except Exception as exc: - yield TriggerEvent( - { - "status": "error", - "job_id": self.job_id, - "response": f"Job {self.job_id} did not succeed. Error was {exc}", - } - ) - - """Cannot call on_kill from trigger, will asynchronously cancel jobs instead.""" - - async def _cancel_job(self) -> None: - try: - result = await self.client.cancel_jobs( - queue=self.armada_queue, - job_set_id=self.job_set_id, - job_id=self.job_id, - ) - if len(list(result.cancelled_ids)) > 0: - self.log.info(f"Cancelled job with id {result.cancelled_ids}") - else: - self.log.warning(f"Failed to cancel job with id {self.job_id}") - except Exception as e: - self.log.warning(f"Failed to cancel job with id {self.job_id}: {e}") - - async def _poll_for_termination(self, job_id: str) -> Dict[str, Any]: - state = JobState.UNKNOWN - start_time = time.time() - job_acknowledged = False - run_details = None - - # Poll for terminal state - while state.is_active(): - resp = await self.client.get_job_status([job_id]) - state = JobState(resp.job_states[job_id]) - self.log.info( - f"Job {job_id} is in state: {state.name}. {self.tracking_message}" - ) - - if state != JobState.UNKNOWN: - job_acknowledged = True - - if ( - not job_acknowledged - and int(time.time() - start_time) > self.job_acknowledgement_timeout - ): - await self._cancel_job() - return { - "status": "error", - "job_id": job_id, - "response": f"Job {job_id} not acknowledged within timeout " - f"{self.job_acknowledgement_timeout}.", - } - - if self.container_logs and not run_details: - if state == JobState.RUNNING or state.is_terminal(): - run_details = await self._get_latest_job_run_details(self.job_id) - - if run_details: - try: - log_status = await self.pod_manager( - run_details.cluster - ).fetch_container_logs( - pod_name=f"armada-{self.job_id}-0", - namespace=self.job_request_namespace, - container_name=self.container_logs, - since_time=self.last_log_time, - ) - self.last_log_time = log_status.last_log_time - except Exception as e: - self.log.exception(e) - - if state.is_active(): - self.log.debug(f"Sleeping for {self.poll_interval} seconds") - await asyncio.sleep(self.poll_interval) - - self.log.info(f"Job {job_id} terminated with state:{state.name}") - if state != JobState.SUCCEEDED: - return { - "status": "error", - "job_id": job_id, - "response": f"Job {job_id} did not succeed. Final status was " - f"{state.name}", - } - return { - "status": "success", - "job_id": job_id, - "response": f"Job {job_id} succeeded", - } - - @cached_property - def client(self) -> ArmadaAsyncIOClient: - return ArmadaAsyncIOClient(channel=self.channel_args.aio_channel()) - - def pod_manager(self, k8s_context: str) -> PodLogManagerAsync: - if self._pod_manager is None: - self._pod_manager = PodLogManagerAsync( - k8s_context=k8s_context, token_retriever=self.k8s_token_retriever - ) - - return self._pod_manager - - async def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: - resp = await self.client.get_job_details([job_id]) - job_details = resp.job_details[job_id] - if job_details and job_details.latest_run_id: - for run in job_details.job_runs: - if run.run_id == job_details.latest_run_id: - return run - return None - - def __eq__(self, other): - if not isinstance(other, ArmadaTrigger): - return False - return ( - self.job_id == other.job_id - and self.channel_args.serialize() == other.channel_args.serialize() - and self.poll_interval == other.poll_interval - and self.tracking_message == other.tracking_message - ) diff --git a/third_party/airflow/docs/source/conf.py b/third_party/airflow/docs/source/conf.py index 15b6b10a646..10d3949aee8 100644 --- a/third_party/airflow/docs/source/conf.py +++ b/third_party/airflow/docs/source/conf.py @@ -12,6 +12,7 @@ # import os import sys + sys.path.insert(0, os.path.abspath('../..')) diff --git a/third_party/airflow/examples/bad_armada.py b/third_party/airflow/examples/bad_armada.py index 11bf545691e..137f4730791 100644 --- a/third_party/airflow/examples/bad_armada.py +++ b/third_party/airflow/examples/bad_armada.py @@ -1,19 +1,12 @@ +import pendulum from airflow import DAG from airflow.operators.bash import BashOperator - from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator - +from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) - -from armada_client.armada import ( - submit_pb2, -) - -import pendulum +from armada_client.k8s.io.apimachinery.pkg.api.resource import \ + generated_pb2 as api_resource def submit_sleep_container(image: str): diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py index 5979e391f0b..ebd84d723ce 100644 --- a/third_party/airflow/examples/big_armada.py +++ b/third_party/airflow/examples/big_armada.py @@ -1,19 +1,12 @@ +import pendulum from airflow import DAG from airflow.operators.bash import BashOperator - from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator - +from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) - -from armada_client.armada import ( - submit_pb2, -) - -import pendulum +from armada_client.k8s.io.apimachinery.pkg.api.resource import \ + generated_pb2 as api_resource def submit_sleep_job(): diff --git a/third_party/airflow/examples/hello_armada.py b/third_party/airflow/examples/hello_armada.py index 0f59932d96c..d3120bdf5f6 100644 --- a/third_party/airflow/examples/hello_armada.py +++ b/third_party/airflow/examples/hello_armada.py @@ -1,19 +1,12 @@ +import pendulum from airflow import DAG from airflow.operators.bash import BashOperator - from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator - +from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) - -from armada_client.armada import ( - submit_pb2, -) - -import pendulum +from armada_client.k8s.io.apimachinery.pkg.api.resource import \ + generated_pb2 as api_resource def submit_sleep_job(): diff --git a/third_party/airflow/examples/hello_armada_deferrable.py b/third_party/airflow/examples/hello_armada_deferrable.py index eb028d61a40..f3e661875d0 100644 --- a/third_party/airflow/examples/hello_armada_deferrable.py +++ b/third_party/airflow/examples/hello_armada_deferrable.py @@ -1,19 +1,12 @@ +import pendulum from airflow import DAG from airflow.operators.bash import BashOperator - -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) - -from armada_client.armada import ( - submit_pb2, -) - -import pendulum - from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator +from armada_client.armada import submit_pb2 +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import \ + generated_pb2 as api_resource def submit_sleep_job(): diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index bde16313944..8f8fb538a57 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "armada_airflow" -version = "1.0.0" +version = "1.0.1" description = "Armada Airflow Operator" readme='README.md' authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] diff --git a/third_party/airflow/test/integration/test_airflow_operator_logic.py b/third_party/airflow/test/integration/test_airflow_operator_logic.py index c2931715f70..4bc3c43418e 100644 --- a/third_party/airflow/test/integration/test_airflow_operator_logic.py +++ b/third_party/airflow/test/integration/test_airflow_operator_logic.py @@ -1,25 +1,21 @@ import os +import threading import uuid +from typing import Any from unittest.mock import MagicMock +import grpc import pytest -import threading - from airflow.exceptions import AirflowException -from armada_client.typings import JobState -from armada_client.armada import ( - submit_pb2, -) +from armada.model import GrpcChannelArgs +from armada.operators.armada import ArmadaOperator +from armada_client.armada import submit_pb2 from armada_client.client import ArmadaClient from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 from armada_client.k8s.io.apimachinery.pkg.api.resource import ( generated_pb2 as api_resource, ) -import grpc -from typing import Any - -from armada.model import GrpcChannelArgs -from armada.operators.armada import ArmadaOperator +from armada_client.typings import JobState DEFAULT_TASK_ID = "test_task_1" DEFAULT_DAG_ID = "test_dag_1" diff --git a/third_party/airflow/test/operators/test_armada.py b/third_party/airflow/test/operators/test_armada.py index 90f6448defa..85129000ad1 100644 --- a/third_party/airflow/test/operators/test_armada.py +++ b/third_party/airflow/test/operators/test_armada.py @@ -1,19 +1,25 @@ import unittest +from datetime import timedelta from math import ceil -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, PropertyMock, patch from airflow.exceptions import AirflowException -from armada_client.armada import submit_pb2, job_pb2 +from armada.model import GrpcChannelArgs +from armada.operators.armada import ( + ArmadaOperator, + _ArmadaPollJobTrigger, + _RunningJobContext, +) +from armada_client.armada import job_pb2, submit_pb2 from armada_client.armada.submit_pb2 import JobSubmitRequestItem from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 from armada_client.k8s.io.apimachinery.pkg.api.resource import ( generated_pb2 as api_resource, ) +from armada_client.typings import JobState +from pendulum import UTC, DateTime -from armada.model import GrpcChannelArgs -from armada.operators.armada import ArmadaOperator -from armada.triggers.armada import ArmadaTrigger - +DEFAULT_CURRENT_TIME = DateTime(2024, 8, 7, tzinfo=UTC) DEFAULT_JOB_ID = "test_job" DEFAULT_TASK_ID = "test_task_1" DEFAULT_DAG_ID = "test_dag_1" @@ -114,7 +120,9 @@ def test_execute(self, mock_client_fn, _): ) @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.on_kill", new_callable=PropertyMock) + @patch( + "armada.operators.armada.ArmadaOperator._cancel_job", new_callable=PropertyMock + ) @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) def test_unacknowledged_results_in_on_kill(self, mock_client_fn, mock_on_kill, _): operator = ArmadaOperator( @@ -139,7 +147,8 @@ def test_unacknowledged_results_in_on_kill(self, mock_client_fn, mock_on_kill, _ ] self.context["ti"].xcom_pull.return_value = None - operator.execute(self.context) + with self.assertRaises(AirflowException): + operator.execute(self.context) self.assertEqual(mock_on_kill.call_count, 1) """We call on_kill by triggering the job unacknowledged timeout""" @@ -177,7 +186,8 @@ def test_on_kill_cancels_job(self, mock_client_fn, _): ] self.context["ti"].xcom_pull.return_value = None - operator.execute(self.context) + with self.assertRaises(AirflowException): + operator.execute(self.context) self.assertEqual(mock_client.cancel_jobs.call_count, 1) @patch("time.sleep", return_value=None) @@ -190,7 +200,7 @@ def test_job_reattaches(self, mock_client_fn, _): job_request=JobSubmitRequestItem(), task_id=DEFAULT_TASK_ID, deferrable=False, - job_acknowledgement_timeout=-1, + job_acknowledgement_timeout=10, ) # Set up Mock Armada @@ -198,7 +208,7 @@ def test_job_reattaches(self, mock_client_fn, _): mock_client.get_job_status.side_effect = [ job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) for x in [ - submit_pb2.UNKNOWN + submit_pb2.SUCCEEDED for _ in range( 1 + ceil( @@ -212,7 +222,6 @@ def test_job_reattaches(self, mock_client_fn, _): operator.execute(self.context) self.assertEqual(mock_client.submit_jobs.call_count, 0) - self.assertEqual(operator.job_id, DEFAULT_JOB_ID) class TestArmadaOperatorDeferrable(unittest.IsolatedAsyncioTestCase): @@ -228,9 +237,10 @@ def setUp(self): "dag": mock_dag, } + @patch("pendulum.DateTime.utcnow") @patch("armada.operators.armada.ArmadaOperator.defer") @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_execute_deferred(self, mock_client_fn, mock_defer_fn): + def test_execute_deferred(self, mock_client_fn, mock_defer_fn, mock_datetime_now): operator = ArmadaOperator( name="test", channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), @@ -240,6 +250,8 @@ def test_execute_deferred(self, mock_client_fn, mock_defer_fn): deferrable=True, ) + mock_datetime_now.return_value = DEFAULT_CURRENT_TIME + # Set up Mock Armada mock_client = MagicMock() mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( @@ -252,17 +264,19 @@ def test_execute_deferred(self, mock_client_fn, mock_defer_fn): self.assertEqual(mock_client.submit_jobs.call_count, 1) mock_defer_fn.assert_called_with( timeout=operator.execution_timeout, - trigger=ArmadaTrigger( - job_id=DEFAULT_JOB_ID, - armada_queue=DEFAULT_QUEUE, - job_set_id=operator.job_set_id, # Not relevant for the sake of test - channel_args=operator.channel_args, - poll_interval=operator.poll_interval, - tracking_message="", - job_acknowledgement_timeout=operator.job_acknowledgement_timeout, - job_request_namespace="default", + trigger=_ArmadaPollJobTrigger( + moment=DEFAULT_CURRENT_TIME + timedelta(seconds=operator.poll_interval), + context=_RunningJobContext( + armada_queue=DEFAULT_QUEUE, + job_set_id=operator.job_set_id, + job_id=DEFAULT_JOB_ID, + state=JobState.UNKNOWN, + start_time=DEFAULT_CURRENT_TIME, + cluster=None, + last_log_time=None, + ), ), - method_name="_execute_complete", + method_name="_deffered_poll_for_termination", ) def test_templating(self): diff --git a/third_party/airflow/test/triggers/__init__.py b/third_party/airflow/test/triggers/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/third_party/airflow/test/triggers/test_armada.py b/third_party/airflow/test/triggers/test_armada.py deleted file mode 100644 index 29ba4f20990..00000000000 --- a/third_party/airflow/test/triggers/test_armada.py +++ /dev/null @@ -1,207 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, patch, PropertyMock - -from airflow.triggers.base import TriggerEvent -from armada_client.armada.submit_pb2 import JobState -from armada_client.armada import submit_pb2, job_pb2 - -from armada.model import GrpcChannelArgs -from armada.triggers.armada import ArmadaTrigger - -DEFAULT_JOB_ID = "test_job" -DEFAULT_QUEUE = "test_queue" -DEFAULT_JOB_SET_ID = "test_job_set_id" -DEFAULT_POLLING_INTERVAL = 30 -DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60 - - -class AsyncMock(unittest.mock.MagicMock): # noqa: F811 - async def __call__(self, *args, **kwargs): - return super(AsyncMock, self).__call__(*args, **kwargs) - - -class TestArmadaTrigger(unittest.IsolatedAsyncioTestCase): - def setUp(self): - self.time = 0 - - def test_serialization(self): - trigger = ArmadaTrigger( - job_id=DEFAULT_JOB_ID, - armada_queue=DEFAULT_QUEUE, - job_set_id=DEFAULT_JOB_SET_ID, - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - poll_interval=30, - tracking_message="test tracking message", - job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, - job_request_namespace="default", - ) - classpath, kwargs = trigger.serialize() - self.assertEqual("armada.triggers.armada.ArmadaTrigger", classpath) - - rehydrated = ArmadaTrigger(**kwargs) - self.assertEqual(trigger, rehydrated) - - def _time_side_effect(self): - self.time += DEFAULT_POLLING_INTERVAL - return self.time - - @patch("time.time") - @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("armada.triggers.armada.ArmadaTrigger.client", new_callable=PropertyMock) - async def test_execute(self, mock_client_fn, _, time_time): - time_time.side_effect = self._time_side_effect - - test_cases = [ - { - "name": "Job Succeeds", - "statuses": [JobState.RUNNING, JobState.SUCCEEDED], - "expected_responses": [ - TriggerEvent( - { - "status": "success", - "job_id": DEFAULT_JOB_ID, - "response": f"Job {DEFAULT_JOB_ID} succeeded", - } - ) - ], - }, - { - "name": "Job Failed", - "statuses": [JobState.RUNNING, JobState.FAILED], - "success": False, - "expected_responses": [ - TriggerEvent( - { - "status": "error", - "job_id": DEFAULT_JOB_ID, - "response": f"Job {DEFAULT_JOB_ID} did not succeed. " - f"Final status was FAILED", - } - ) - ], - }, - { - "name": "Job cancelled", - "statuses": [JobState.RUNNING, JobState.CANCELLED], - "success": False, - "expected_responses": [ - TriggerEvent( - { - "status": "error", - "job_id": DEFAULT_JOB_ID, - "response": f"Job {DEFAULT_JOB_ID} did not succeed." - f" Final status was CANCELLED", - } - ) - ], - }, - { - "name": "Job unacknowledged", - "statuses": [JobState.UNKNOWN for _ in range(6)], - "success": False, - "expected_responses": [ - TriggerEvent( - { - "status": "error", - "job_id": DEFAULT_JOB_ID, - "response": f"Job {DEFAULT_JOB_ID} not acknowledged wit" - f"hin timeout {DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT}.", - } - ) - ], - }, - { - "name": "Job preempted", - "statuses": [JobState.RUNNING, JobState.PREEMPTED], - "success": False, - "expected_responses": [ - TriggerEvent( - { - "status": "error", - "job_id": DEFAULT_JOB_ID, - "response": f"Job {DEFAULT_JOB_ID} did not succeed." - f" Final status was PREEMPTED", - } - ) - ], - }, - { - "name": "Job Succeeds but takes a lot of transitions", - "statuses": [ - JobState.SUBMITTED, - JobState.RUNNING, - JobState.RUNNING, - JobState.RUNNING, - JobState.RUNNING, - JobState.RUNNING, - JobState.SUCCEEDED, - ], - "success": True, - "expected_responses": [ - TriggerEvent( - { - "status": "success", - "job_id": DEFAULT_JOB_ID, - "response": f"Job {DEFAULT_JOB_ID} succeeded", - } - ) - ], - }, - ] - - for test_case in test_cases: - with self.subTest(test_case=test_case["name"]): - trigger = ArmadaTrigger( - job_id=DEFAULT_JOB_ID, - armada_queue=DEFAULT_QUEUE, - job_set_id=DEFAULT_JOB_SET_ID, - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - poll_interval=DEFAULT_POLLING_INTERVAL, - tracking_message="some tracking message", - job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, - job_request_namespace="default", - ) - - # Setup Mock Armada - mock_client = AsyncMock() - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in test_case["statuses"] - ] - mock_client.cancel_jobs.return_value = submit_pb2.CancellationResult( - cancelled_ids=[DEFAULT_JOB_ID] - ) - mock_client_fn.return_value = mock_client - responses = [gen async for gen in trigger.run()] - self.assertEqual(test_case["expected_responses"], responses) - self.assertEqual( - len(test_case["statuses"]), mock_client.get_job_status.call_count - ) - - @patch("time.sleep", return_value=None) - @patch("armada.triggers.armada.ArmadaTrigger.client", new_callable=PropertyMock) - async def test_unacknowledged_results_in_job_cancel(self, mock_client_fn, _): - trigger = ArmadaTrigger( - job_id=DEFAULT_JOB_ID, - armada_queue=DEFAULT_QUEUE, - job_set_id=DEFAULT_JOB_SET_ID, - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - poll_interval=DEFAULT_POLLING_INTERVAL, - tracking_message="some tracking message", - job_acknowledgement_timeout=-1, - job_request_namespace="default", - ) - - # Set up Mock Armada - mock_client = AsyncMock() - mock_client.cancel_jobs.return_value = submit_pb2.CancellationResult( - cancelled_ids=[DEFAULT_JOB_ID] - ) - mock_client_fn.return_value = mock_client - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [JobState.UNKNOWN, JobState.UNKNOWN] - ] - [gen async for gen in trigger.run()] - - self.assertEqual(mock_client.cancel_jobs.call_count, 1)