diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index ec7cbf7fd941..6a79d033dbaf 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -586,15 +586,15 @@ def _change_state(self, key: TaskInstanceKey, state: Optional[str], pod_id: str, self.event_buffer[key] = state, None def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]: - tis_to_flush = [ti for ti in tis if not ti.external_executor_id] - scheduler_job_ids = [ti.external_executor_id for ti in tis] + tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id] + scheduler_job_ids = {ti.queued_by_job_id for ti in tis} pod_ids = { create_pod_id( dag_id=pod_generator.make_safe_label_value(ti.dag_id), task_id=pod_generator.make_safe_label_value(ti.task_id), ): ti for ti in tis - if ti.external_executor_id + if ti.queued_by_job_id } kube_client: client.CoreV1Api = self.kube_client for scheduler_job_id in scheduler_job_ids: diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 37cdf3e06b82..e4fff5f4ecb8 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -31,7 +31,7 @@ from contextlib import redirect_stderr, redirect_stdout, suppress from datetime import timedelta from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Any, Callable, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Callable, DefaultDict, Dict, Iterator, Iterable, List, Optional, Set, Tuple from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_, tuple_ @@ -1218,7 +1218,15 @@ def _process_executor_events(self, session: Session = None) -> int: # Check state of finished tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) - tis: List[TI] = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')).all() + query = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')) + # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have + # multi-schedulers + tis: Iterator[TI] = with_row_locks( + query, + of=TI, + session=session, + **skip_locked(session=session), + ) for ti in tis: try_number = ti_primary_key_to_try_number_map[ti.key.primary] buffer_key = ti.key.with_try_number(try_number) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 119116eb766e..dc201b7a61c4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -623,7 +623,10 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: self.priority_weight = ti.priority_weight self.operator = ti.operator self.queued_dttm = ti.queued_dttm + self.queued_by_job_id = ti.queued_by_job_id self.pid = ti.pid + self.executor_config = ti.executor_config + self.external_executor_id = ti.external_executor_id else: self.state = None diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 8d3d5b4d450e..ab476fdf0476 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -419,6 +419,82 @@ def test_not_adopt_unassigned_task(self, mock_kube_client): assert not mock_kube_client.patch_namespaced_pod.called assert pod_ids == {"foobar": {}} + @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.adopt_launched_task') + @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods') + def test_try_adopt_task_instances(self, mock_adopt_completed_pods, mock_adopt_launched_task): + executor = self.kubernetes_executor + executor.scheduler_job_id = "10" + mock_ti = mock.MagicMock(queued_by_job_id="1", external_executor_id="1", dag_id="dag", task_id="task") + pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo", labels={"dag_id": "dag", "task_id": "task"})) + pod_id = create_pod_id(dag_id="dag", task_id="task") + mock_kube_client = mock.MagicMock() + mock_kube_client.list_namespaced_pod.return_value.items = [pod] + executor.kube_client = mock_kube_client + + # First adoption + executor.try_adopt_task_instances([mock_ti]) + mock_kube_client.list_namespaced_pod.assert_called_once_with( + namespace='default', label_selector='airflow-worker=1' + ) + mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {pod_id: mock_ti}) + mock_adopt_completed_pods.assert_called_once() + # We aren't checking the return value of `try_adopt_task_instances` because it relies on + # `adopt_launched_task` mutating its arg. This should be refactored, but not right now. + + # Second adoption (queued_by_job_id and external_executor_id no longer match) + mock_kube_client.reset_mock() + mock_adopt_launched_task.reset_mock() + mock_adopt_completed_pods.reset_mock() + + mock_ti.queued_by_job_id = "10" # scheduler_job would have updated this after the first adoption + executor.scheduler_job_id = "20" + + executor.try_adopt_task_instances([mock_ti]) + mock_kube_client.list_namespaced_pod.assert_called_once_with( + namespace='default', label_selector='airflow-worker=10' + ) + mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {pod_id: mock_ti}) + mock_adopt_completed_pods.assert_called_once() + + @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods') + def test_try_adopt_task_instances_multiple_scheduler_ids(self, mock_adopt_completed_pods): + """We try to find pods only once per scheduler id""" + executor = self.kubernetes_executor + mock_kube_client = mock.MagicMock() + executor.kube_client = mock_kube_client + + mock_tis = [ + mock.MagicMock(queued_by_job_id="10", external_executor_id="1", dag_id="dag", task_id="task"), + mock.MagicMock(queued_by_job_id="40", external_executor_id="1", dag_id="dag", task_id="task2"), + mock.MagicMock(queued_by_job_id="40", external_executor_id="1", dag_id="dag", task_id="task3"), + ] + + executor.try_adopt_task_instances(mock_tis) + assert mock_kube_client.list_namespaced_pod.call_count == 2 + mock_kube_client.list_namespaced_pod.assert_has_calls( + [ + mock.call(namespace='default', label_selector='airflow-worker=10'), + mock.call(namespace='default', label_selector='airflow-worker=40'), + ], + any_order=True, + ) + + @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.adopt_launched_task') + @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods') + def test_try_adopt_task_instances_no_matching_pods( + self, mock_adopt_completed_pods, mock_adopt_launched_task + ): + executor = self.kubernetes_executor + mock_ti = mock.MagicMock(queued_by_job_id="1", external_executor_id="1", dag_id="dag", task_id="task") + mock_kube_client = mock.MagicMock() + mock_kube_client.list_namespaced_pod.return_value.items = [] + executor.kube_client = mock_kube_client + + tis_to_flush = executor.try_adopt_task_instances([mock_ti]) + assert tis_to_flush == [mock_ti] + mock_adopt_launched_task.assert_not_called() + mock_adopt_completed_pods.assert_called_once() + class TestKubernetesJobWatcher(unittest.TestCase): def setUp(self):