diff --git a/api/edge_api/identities/edge_request_forwarder.py b/api/edge_api/identities/edge_request_forwarder.py index 9ea0e65fbd80..42e7f64696b6 100644 --- a/api/edge_api/identities/edge_request_forwarder.py +++ b/api/edge_api/identities/edge_request_forwarder.py @@ -14,7 +14,7 @@ def _should_forward(project_id: int) -> bool: return bool(migrator.is_migration_done) -@register_task_handler() +@register_task_handler(queue_size=2000) def forward_identity_request( request_method: str, headers: dict, @@ -35,7 +35,7 @@ def forward_identity_request( requests.get(url, params=query_params, headers=headers, timeout=5) -@register_task_handler() +@register_task_handler(queue_size=2000) def forward_trait_request( request_method: str, headers: dict, @@ -52,7 +52,6 @@ def forward_trait_request_sync( return url = settings.EDGE_API_URL + "traits/" - payload = payload payload = json.dumps(payload) requests.post( url, @@ -62,7 +61,7 @@ def forward_trait_request_sync( ) -@register_task_handler() +@register_task_handler(queue_size=1000) def forward_trait_requests( request_method: str, headers: str, diff --git a/api/task_processor/decorators.py b/api/task_processor/decorators.py index 76c1d639b1cb..a942757960b6 100644 --- a/api/task_processor/decorators.py +++ b/api/task_processor/decorators.py @@ -8,7 +8,7 @@ from django.conf import settings from django.utils import timezone -from task_processor.exceptions import InvalidArgumentsError +from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError from task_processor.models import RecurringTask, Task from task_processor.task_registry import register_task from task_processor.task_run_method import TaskRunMethod @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def register_task_handler(task_name: str = None): +def register_task_handler(task_name: str = None, queue_size: int = None): def decorator(f: typing.Callable): nonlocal task_name @@ -49,12 +49,18 @@ def delay( run_in_thread(args=args, kwargs=kwargs) else: logger.debug("Creating task for function '%s'...", task_identifier) - task = Task.schedule_task( - schedule_for=delay_until or timezone.now(), - task_identifier=task_identifier, - args=args, - kwargs=kwargs, - ) + try: + task = Task.schedule_task( + schedule_for=delay_until or timezone.now(), + task_identifier=task_identifier, + queue_size=queue_size, + args=args, + kwargs=kwargs, + ) + except TaskQueueFullError as e: + logger.warning(e) + return + task.save() return task diff --git a/api/task_processor/exceptions.py b/api/task_processor/exceptions.py index 12cf27f73a7e..7f697a6e7ba3 100644 --- a/api/task_processor/exceptions.py +++ b/api/task_processor/exceptions.py @@ -4,3 +4,7 @@ class TaskProcessingError(Exception): class InvalidArgumentsError(TaskProcessingError): pass + + +class TaskQueueFullError(Exception): + pass diff --git a/api/task_processor/models.py b/api/task_processor/models.py index 87093f04fadc..3ef982c806b5 100644 --- a/api/task_processor/models.py +++ b/api/task_processor/models.py @@ -6,7 +6,7 @@ from django.db import models from django.utils import timezone -from task_processor.exceptions import TaskProcessingError +from task_processor.exceptions import TaskProcessingError, TaskQueueFullError from task_processor.managers import RecurringTaskManager, TaskManager from task_processor.task_registry import registered_tasks @@ -105,10 +105,22 @@ def schedule_task( cls, schedule_for: datetime, task_identifier: str, + queue_size: typing.Optional[int], *, args: typing.Tuple[typing.Any] = None, kwargs: typing.Dict[str, typing.Any] = None, ) -> "Task": + if queue_size: + if ( + cls.objects.filter( + task_identifier=task_identifier, completed=False, num_failures__lt=3 + ).count() + > queue_size + ): + raise TaskQueueFullError( + f"Queue for task {task_identifier} is full. " + f"Max queue size is {queue_size}" + ) task = cls.create( task_identifier=task_identifier, args=args, diff --git a/api/tests/unit/task_processor/test_unit_task_processor_models.py b/api/tests/unit/task_processor/test_unit_task_processor_models.py index 41d492146cf5..989a4a3573cd 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_models.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_models.py @@ -5,6 +5,7 @@ from django.utils import timezone from task_processor.decorators import register_task_handler +from task_processor.exceptions import TaskQueueFullError from task_processor.models import RecurringTask, Task now = timezone.now() @@ -54,3 +55,42 @@ def test_recurring_task_run_should_execute_first_run_at(first_run_time, expected ).should_execute == expected ) + + +def test_schedule_task_raises_error_if_queue_is_full(db): + # Given + task_identifier = "my_callable" + + # some incomplete task + for _ in range(10): + Task.objects.create(task_identifier=task_identifier) + + # When + with pytest.raises(TaskQueueFullError): + Task.schedule_task( + schedule_for=timezone.now(), task_identifier=task_identifier, queue_size=9 + ) + + +def test_can_schedule_task_raises_error_if_queue_is_not_full(db): + # Given + task_identifier = "my_callable" + + # Some incomplete task + for _ in range(10): + Task.objects.create(task_identifier=task_identifier) + + # tasks with different identifiers + Task.objects.create(task_identifier="task_with_different_identifier") + + # failed tasks + Task.objects.create( + task_identifier="task_with_different_identifier", num_failures=3 + ) + + # When + task = Task.schedule_task( + schedule_for=timezone.now(), task_identifier=task_identifier, queue_size=10 + ) + # Then + assert task is not None