From 94ed16ef36498a6c6ee30b441f6d36184a70f5e5 Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 10 Apr 2024 12:49:32 -0400 Subject: [PATCH] fix: only stop GPU workers when 0 messages are active (#365) --- shapeworks_cloud/manage_workers.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/shapeworks_cloud/manage_workers.py b/shapeworks_cloud/manage_workers.py index eaf2880b..5c85c19b 100644 --- a/shapeworks_cloud/manage_workers.py +++ b/shapeworks_cloud/manage_workers.py @@ -11,7 +11,8 @@ def inspect_queue(queue_name): from .celery import app # this function requires pyrabbit and the rabbitmq management port - num_messages = -1 + num_messages_ready = -1 + num_messages_active = -1 with app.pool.acquire(block=True) as conn: try: manager = conn.get_manager() @@ -26,11 +27,14 @@ def inspect_queue(queue_name): ) vhost = manager.user queue = manager.get_queue(vhost, queue_name) - num_messages = queue.get('messages_ready', num_messages) + num_messages = queue.get('messages', -1) + num_messages_ready = queue.get('messages_ready', -1) + if num_messages >= 0: + num_messages_active = num_messages - num_messages_ready except pyrabbit.http.HTTPError: # queue doesn't exist yet, wait for a spawned task to create it pass - return num_messages + return num_messages_ready, num_messages_active def get_all_workers(client): @@ -76,18 +80,18 @@ def manage_workers(**kwargs): if v is not None: os.environ[k] = v - num_queued = inspect_queue('gpu') - if num_queued < 0: + num_messages_ready, num_messages_active = inspect_queue('gpu') + if num_messages_ready < 0: return - print(f'{num_queued} tasks in queue.') + print(f'{num_messages_ready} tasks ready, {num_messages_active} tasks active.') client = boto3.client('ec2') gpu_workers = get_gpu_workers(client) - if num_queued > 0: + if num_messages_ready > 0: ids_to_start = [w['id'] for w in gpu_workers if not w['hostname']] - if len(ids_to_start) > num_queued: - ids_to_start = ids_to_start[:num_queued] + if len(ids_to_start) > num_messages_ready: + ids_to_start = ids_to_start[:num_messages_ready] if len(ids_to_start) > 0: print(f'Starting instances {ids_to_start}.') @@ -95,7 +99,7 @@ def manage_workers(**kwargs): else: print('All available GPU workers are live. Tasks in queue must wait.') - else: + elif num_messages_active == 0: ids_to_stop = [w['id'] for w in gpu_workers if w['hostname']] if len(ids_to_stop) > 0: