
Commit

Remove schedule downstream tasks after execution (aka "mini scheduler") (#43741)

It has long been questionable how much benefit this actually provided, but with the move
towards task DB isolation in Airflow 3 we won't be able to keep it anymore
(just as we couldn't when AIP-44 DB isolation was enabled), so let's remove it now.
ashb authored Nov 6, 2024
1 parent c7c6547 commit d41c859
Showing 9 changed files with 2 additions and 571 deletions.
9 changes: 0 additions & 9 deletions airflow/config_templates/config.yml
@@ -2418,15 +2418,6 @@ scheduler:
type: integer
default: "20"
see_also: ":ref:`scheduler:ha:tunables`"
schedule_after_task_execution:
description: |
Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the
same DAG. Leaving this on will mean tasks in the same DAG execute quicker, but might starve out other
dags in some circumstances
example: ~
version_added: 2.0.0
type: boolean
default: "True"
parsing_pre_import_modules:
description: |
The scheduler reads dag files to extract the airflow modules that are going to be used,
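For context, the flag removed above was read at runtime in airflow/jobs/local_task_job_runner.py (next file). A minimal sketch of that lookup, assuming a configured Airflow environment:

from airflow.configuration import conf

# Mirrors the check deleted from handle_task_exit() below; the option
# defaulted to True when unset.
mini_scheduler_enabled = conf.getboolean(
    "scheduler", "schedule_after_task_execution", fallback=True
)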
6 changes: 1 addition & 5 deletions airflow/jobs/local_task_job_runner.py
@@ -253,7 +253,7 @@ def handle_task_exit(self, return_code: int) -> None:
self.terminating = True
self._log_return_code_metric(return_code)

if is_deferral := return_code == TaskReturnCode.DEFERRED.value:
if return_code == TaskReturnCode.DEFERRED.value:
self.log.info("Task exited with return code %s (task deferral)", return_code)
_set_task_deferred_context_var()
else:
@@ -262,10 +262,6 @@ def handle_task_exit(self, return_code: int) -> None:
message += ". For more information, see https://airflow.apache.org/docs/apache-airflow/stable/troubleshooting.html#LocalTaskJob-killed"
self.log.info(message)

if not (self.task_instance.test_mode or is_deferral):
if conf.getboolean("scheduler", "schedule_after_task_execution", fallback=True):
self.task_instance.schedule_downstream_tasks(max_tis_per_query=self.job.max_tis_per_query)

def on_kill(self):
self.task_runner.terminate()
self.task_runner.on_finish()
94 changes: 0 additions & 94 deletions airflow/models/taskinstance.py
@@ -3639,100 +3639,6 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum
return filters[0]
return or_(*filters)

@classmethod
@provide_session
def _schedule_downstream_tasks(
cls,
ti: TaskInstance | TaskInstancePydantic,
session: Session = NEW_SESSION,
max_tis_per_query: int | None = None,
):
from sqlalchemy.exc import OperationalError

from airflow.models.dagrun import DagRun

try:
# Re-select the row with a lock
dag_run = with_row_locks(
session.query(DagRun).filter_by(
dag_id=ti.dag_id,
run_id=ti.run_id,
),
session=session,
skip_locked=True,
).one_or_none()

if not dag_run:
cls.logger().debug("Skip locked rows, rollback")
session.rollback()
return

task = ti.task
if TYPE_CHECKING:
assert task
assert task.dag

# Previously, this section used task.dag.partial_subset to retrieve a partial DAG.
# However, this approach is unsafe as it can result in incomplete or incorrect task execution,
# leading to potential bad cases. As a result, the operation has been removed.
# For more details, refer to the discussion in PR #[https://github.com/apache/airflow/pull/42582].
dag_run.dag = task.dag
info = dag_run.task_instance_scheduling_decisions(session)

skippable_task_ids = {
task_id for task_id in task.dag.task_ids if task_id not in task.downstream_task_ids
}

schedulable_tis = [
ti
for ti in info.schedulable_tis
if ti.task_id not in skippable_task_ids
and not (
ti.task.inherits_from_empty_operator
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
and not ti.task.outlets
)
]
for schedulable_ti in schedulable_tis:
if getattr(schedulable_ti, "task", None) is None:
schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
cls.logger().info("%d downstream tasks scheduled from follow-on schedule check", num)

session.flush()

except OperationalError as e:
# Any kind of DB error here is _non fatal_ as this block is just an optimisation.
cls.logger().warning(
"Skipping mini scheduling run due to exception: %s",
e.statement,
exc_info=True,
)
session.rollback()

@provide_session
def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
"""
Schedule downstream tasks of this task instance.
:meta: private
"""
try:
return TaskInstance._schedule_downstream_tasks(
ti=self, session=session, max_tis_per_query=max_tis_per_query
)
except Exception:
self.log.exception(
"Error scheduling downstream tasks. Skipping it as this is entirely optional optimisation. "
"There might be various reasons for it, please take a look at the stack trace to figure "
"out if the root cause can be diagnosed and fixed. See the issue "
"https://github.com/apache/airflow/issues/39717 for details and an example problem. If you "
"would like to get help in solving root cause, open discussion with all details with your "
"managed service support or in Airflow repository."
)

def get_relevant_upstream_map_indexes(
self,
upstream: Operator,
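To recap the logic removed above: after a task finished, the mini scheduler re-selected its DagRun under a row lock (skipping locked rows), computed that run's scheduling decisions, and kept only the finished task's direct downstream tasks, additionally dropping trivial EmptyOperator tasks with no callbacks and no outlets (the scheduler completes those without a worker). A minimal self-contained sketch of that filtering rule, using hypothetical names rather than Airflow's API:

# Hypothetical standalone model of the removed filtering rule; names
# here are illustrative, not Airflow API.
def mini_schedulable(all_task_ids, downstream_task_ids, schedulable):
    """Return the task_ids the mini scheduler would have kept.

    schedulable: (task_id, is_trivial_empty) pairs, where
    is_trivial_empty marks an EmptyOperator with no callbacks/outlets.
    """
    # Anything that is not a direct downstream of the finished task
    # was treated as skippable.
    skippable = set(all_task_ids) - set(downstream_task_ids)
    return [
        task_id
        for task_id, trivial in schedulable
        if task_id not in skippable and not trivial
    ]

# A -> B -> C: after A succeeds, only B is kept; C is not a direct
# downstream of A, so it is left to the real scheduler.
assert mini_schedulable(["A", "B", "C"], ["B"], [("B", False)]) == ["B"]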
10 changes: 0 additions & 10 deletions airflow/serialization/pydantic/taskinstance.py
@@ -467,16 +467,6 @@ def check_and_change_state_before_execution(
session=session,
)

def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
"""
Schedule downstream tasks of this task instance.
:meta: private
"""
# we should not schedule downstream tasks with Pydantic model because it will not be able to
# get the DAG object (we do not serialize it currently).
return

def command_as_list(
self,
mark_success: bool = False,
@@ -382,8 +382,3 @@ However, you can also look at other non-performance-related scheduler configurat
in the loop. i.e. if it scheduled something then it will start the next loop
iteration straight away. This parameter is badly named (historical reasons) and it will be
renamed in the future with deprecation of the current name.

- :ref:`config:scheduler__schedule_after_task_execution`
Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of
the same DAG. Leaving this on will mean tasks in the same DAG execute quicker,
but might starve out other DAGs in some circumstances.
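For reference, the flag documented above could also be flipped through Airflow's standard AIRFLOW__<SECTION>__<KEY> environment-variable override, which is how the Edge provider disabled it (next file):

import os

# Same effect as setting schedule_after_task_execution = False in the
# [scheduler] section of airflow.cfg; used by edge_command.py below to
# leave downstream scheduling to the core scheduler.
os.environ["AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION"] = "False"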
2 changes: 0 additions & 2 deletions providers/src/airflow/providers/edge/cli/edge_command.py
@@ -77,8 +77,6 @@ def force_use_internal_api_on_edge_worker():
os.environ["AIRFLOW_ENABLE_AIP_44"] = "True"
os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url
InternalApiConfig.set_use_internal_api("edge-worker")
# Disable mini-scheduler post task execution and leave next task schedule to core scheduler
os.environ["AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION"] = "False"


force_use_internal_api_on_edge_worker()
25 changes: 0 additions & 25 deletions providers/tests/cncf/kubernetes/decorators/test_kubernetes.py
@@ -215,28 +215,3 @@ def f():
assert len(dag.task_group.children) == 1
teardown_task = dag.task_group.children["f"]
assert teardown_task.is_teardown


# Database isolation mode does not support mini-scheduler
@pytest.mark.skip_if_database_isolation_mode
def test_kubernetes_with_mini_scheduler(
dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
) -> None:
with dag_maker(session=session):

@task.kubernetes(
image="python:3.10-slim-buster",
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
)
def f(arg1, arg2, kwarg1=None, kwarg2=None):
return {"key1": "value1", "key2": "value2"}

f1 = f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
f.override(task_id="my_task_id2", do_xcom_push=False)("arg1", "arg2", kwarg1=f1)

dr = dag_maker.create_dagrun()
(ti, _) = dr.task_instances
# check that mini-scheduler works
ti.schedule_downstream_tasks()
158 changes: 0 additions & 158 deletions tests/jobs/test_local_task_job.py
@@ -37,11 +37,9 @@
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import SIGSEGV_MESSAGE, LocalTaskJobRunner
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
@@ -845,162 +843,6 @@ def send_signal(ti, signal_sent, sig):
lines = f.readlines()
assert len(lines) == 0

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.parametrize(
"conf, init_state, first_run_state, second_run_state, task_ids_to_run, error_message",
[
(
{("scheduler", "schedule_after_task_execution"): "True"},
{"A": State.QUEUED, "B": State.NONE, "C": State.NONE},
{"A": State.SUCCESS, "B": State.SCHEDULED, "C": State.NONE},
{"A": State.SUCCESS, "B": State.SUCCESS, "C": State.SCHEDULED},
["A", "B"],
"A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
),
(
{("scheduler", "schedule_after_task_execution"): "False"},
{"A": State.QUEUED, "B": State.NONE, "C": State.NONE},
{"A": State.SUCCESS, "B": State.NONE, "C": State.NONE},
None,
["A", "B"],
"A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
),
(
{("scheduler", "schedule_after_task_execution"): "True"},
{"D": State.QUEUED, "E": State.NONE, "F": State.NONE, "G": State.NONE},
{"D": State.SUCCESS, "E": State.NONE, "F": State.NONE, "G": State.NONE},
None,
["D", "E"],
"G -> F -> E & D -> E, when D runs but F isn't QUEUED yet, E shouldn't be QUEUED.",
),
(
{("scheduler", "schedule_after_task_execution"): "True"},
{"H": State.QUEUED, "I": State.FAILED, "J": State.NONE},
{"H": State.SUCCESS, "I": State.FAILED, "J": State.UPSTREAM_FAILED},
None,
["H", "I"],
"H -> J & I -> J, when H is QUEUED but I has FAILED, J is marked UPSTREAM_FAILED.",
),
],
)
def test_fast_follow(
self,
conf,
init_state,
first_run_state,
second_run_state,
task_ids_to_run,
error_message,
get_test_dag,
):
with conf_vars(conf):
dag = get_test_dag(
"test_dagrun_fast_follow",
)

scheduler_job = Job()
scheduler_job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
scheduler_job_runner.dagbag.bag_dag(dag)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}

dag_run = dag.create_dagrun(
run_id="test_dagrun_fast_follow", state=State.RUNNING, **triggered_by_kwargs
)

ti_by_task_id = {}
with create_session() as session:
for task_id in init_state:
ti = TaskInstance(dag.get_task(task_id), run_id=dag_run.run_id)
ti.refresh_from_db()
ti.state = init_state[task_id]
session.merge(ti)
ti_by_task_id[task_id] = ti

ti = TaskInstance(task=dag.get_task(task_ids_to_run[0]), run_id=dag_run.run_id)
ti.refresh_from_db()
job1 = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True)
job1.task_runner = StandardTaskRunner(job_runner)

run_job(job=job1, execute_callable=job_runner._execute)
self.validate_ti_states(dag_run, first_run_state, error_message)
if second_run_state:
ti = TaskInstance(task=dag.get_task(task_ids_to_run[1]), run_id=dag_run.run_id)
ti.refresh_from_db()
job2 = Job(dag_id=ti.dag_id, executor=SequentialExecutor())
job_runner = LocalTaskJobRunner(job=job2, task_instance=ti, ignore_ti_state=True)
job2.task_runner = StandardTaskRunner(job_runner)
run_job(job2, execute_callable=job_runner._execute)
self.validate_ti_states(dag_run, second_run_state, error_message)
if scheduler_job_runner.processor_agent:
scheduler_job_runner.processor_agent.end()

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@conf_vars({("scheduler", "schedule_after_task_execution"): "True"})
def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag):
dag = get_test_dag("test_dagrun_fast_follow")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dag.catchup = False
SerializedDagModel.write_dag(dag)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}

dr = dag.create_dagrun(
run_id="test_1",
state=State.RUNNING,
execution_date=DEFAULT_DATE,
data_interval=data_interval,
**triggered_by_kwargs,
)
dr2 = dag.create_dagrun(
run_id="test_2",
state=State.RUNNING,
execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
data_interval=data_interval,
**triggered_by_kwargs,
)
task_k = dag.get_task("K")
task_l = dag.get_task("L")
with create_session() as session:
ti_k = dr.get_task_instance(task_k.task_id, session=session)
ti_k.refresh_from_task(task_k)
ti_k.state = State.SUCCESS

ti_b = dr.get_task_instance(task_l.task_id, session=session)
ti_b.refresh_from_task(task_l)
ti_b.state = State.SUCCESS

ti2_k = dr2.get_task_instance(task_k.task_id, session=session)
ti2_k.refresh_from_task(task_k)
ti2_k.state = State.NONE

ti2_l = dr2.get_task_instance(task_l.task_id, session=session)
ti2_l.refresh_from_task(task_l)
ti2_l.state = State.NONE

session.merge(ti_k)
session.merge(ti_b)

session.merge(ti2_k)
session.merge(ti2_l)

job1 = Job(
executor=SequentialExecutor(),
dag_id=ti2_k.dag_id,
)
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti2_k, ignore_ti_state=True)
job1.task_runner = StandardTaskRunner(job_runner)
run_job(job=job1, execute_callable=job_runner._execute)

ti2_k.refresh_from_db()
ti2_l.refresh_from_db()
assert ti2_k.state == State.SUCCESS
assert ti2_l.state == State.NONE

failed_deps = list(ti2_l.get_failed_dep_statuses())
assert len(failed_deps) == 1
assert failed_deps[0].dep_name == "Previous Dagrun State"
assert not failed_deps[0].passed

def test_process_sigsegv_error_message(self, caplog, dag_maker):
"""Test that shows error if process failed with segmentation fault."""
caplog.set_level(logging.CRITICAL, logger="local_task_job.py")