diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index fa7dbff1cf..01721c85da 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -129,9 +129,11 @@ def _spawn_worker(self, num: int = 1): wdir_hash = hashlib.sha256(self.wdir.encode("utf-8")).hexdigest()[:6] node_name = f"dvc-exp-{wdir_hash}-{num}@localhost" cmd = ["exp", "queue-worker", node_name] - if num == 1: - # automatically run celery cleanup when primary worker shuts down - cmd.append("--clean") + + # Always clean the queues as non expired messages will be excluded by dvc_task + # effectively skipping the cleaning. + cmd.append("--clean") + if logger.getEffectiveLevel() <= logging.DEBUG: cmd.append("-v") name = f"dvc-exp-worker-{num}" diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index df7bd52156..ab730d5bc4 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -1,6 +1,7 @@ import itertools import logging import os +import re import stat from textwrap import dedent @@ -27,6 +28,7 @@ from dvc.stage.exceptions import StageFileDoesNotExistError from dvc.testing.scripts import COPY_SCRIPT from dvc.utils.serialize import PythonFileCorruptedError +from dvc_task.proc.process import ManagedProcess @pytest.mark.parametrize("name", [None, "foo"]) @@ -479,6 +481,32 @@ def test_run_celery(tmp_dir, scm, dvc, exp_stage, mocker): assert expected == metrics +def test_run_celery_queues_two_jobs_each_one_with_cleaning_flag( + tmp_dir, scm, dvc, exp_stage, mocker +): + dvc.experiments.run(exp_stage.addressing, params=["foo=2"], queue=True) + dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True) + assert len(dvc.experiments.stash_revs) == 2 + + repro_spy = mocker.spy(dvc.experiments, "reproduce_celery") + spawn_spy = mocker.spy(ManagedProcess, "spawn") + dvc.experiments.run(run_all=True, jobs=2) + + repro_spy.assert_called_once_with(jobs=2) + + call_1_args, call_2_args = [spawn_spy.call_args_list[n].args[0] for n in (0, 1)] + + pattern = r"^dvc-exp-[0-9A-Fa-f]{6}-[1,2]@localhost$" # dvc-exp-4c8d13-1@localhost + first_queue, second_queue = [ + re.match(pattern, call_arg[3]).group(0) + for call_arg in [call_1_args, call_2_args] + ] + + assert call_1_args == ["dvc", "exp", "queue-worker", first_queue, "--clean", "-v"] + assert call_2_args == ["dvc", "exp", "queue-worker", second_queue, "--clean", "-v"] + assert spawn_spy.call_count == 2 + + def test_checkout_targets_deps(tmp_dir, scm, dvc, exp_stage): from dvc.utils.fs import remove diff --git a/tests/unit/command/test_queue.py b/tests/unit/command/test_queue.py index 19618eb3c8..aab7abfe61 100644 --- a/tests/unit/command/test_queue.py +++ b/tests/unit/command/test_queue.py @@ -96,6 +96,10 @@ def test_experiments_start(dvc, scm, mocker): assert cmd.run() == 0 assert m.call_count == 3 + # Ensure each call to _spawn_worker will be for the nth worker + for n, call_arg in enumerate(m.call_args_list, start=1): + assert call_arg[0][0] == n + def test_experiments_stop(dvc, scm, mocker): cli_args = parse_args(["queue", "stop", "--kill"])