Skip to content

Commit

Permalink
fix: only stop GPU workers when 0 messages are active (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
annehaley authored Apr 10, 2024
1 parent 7fcf40c commit 94ed16e
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions shapeworks_cloud/manage_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def inspect_queue(queue_name):
from .celery import app

# this function requires pyrabbit and the rabbitmq management port
num_messages = -1
num_messages_ready = -1
num_messages_active = -1
with app.pool.acquire(block=True) as conn:
try:
manager = conn.get_manager()
Expand All @@ -26,11 +27,14 @@ def inspect_queue(queue_name):
)
vhost = manager.user
queue = manager.get_queue(vhost, queue_name)
num_messages = queue.get('messages_ready', num_messages)
num_messages = queue.get('messages', -1)
num_messages_ready = queue.get('messages_ready', -1)
if num_messages >= 0:
num_messages_active = num_messages - num_messages_ready
except pyrabbit.http.HTTPError:
# queue doesn't exist yet, wait for a spawned task to create it
pass
return num_messages
return num_messages_ready, num_messages_active


def get_all_workers(client):
Expand Down Expand Up @@ -76,26 +80,26 @@ def manage_workers(**kwargs):
if v is not None:
os.environ[k] = v

num_queued = inspect_queue('gpu')
if num_queued < 0:
num_messages_ready, num_messages_active = inspect_queue('gpu')
if num_messages_ready < 0:
return
print(f'{num_queued} tasks in queue.')
print(f'{num_messages_ready} tasks ready, {num_messages_active} tasks active.')

client = boto3.client('ec2')
gpu_workers = get_gpu_workers(client)

if num_queued > 0:
if num_messages_ready > 0:
ids_to_start = [w['id'] for w in gpu_workers if not w['hostname']]
if len(ids_to_start) > num_queued:
ids_to_start = ids_to_start[:num_queued]
if len(ids_to_start) > num_messages_ready:
ids_to_start = ids_to_start[:num_messages_ready]

if len(ids_to_start) > 0:
print(f'Starting instances {ids_to_start}.')
print(client.start_instances(InstanceIds=ids_to_start))
else:
print('All available GPU workers are live. Tasks in queue must wait.')

else:
elif num_messages_active == 0:
ids_to_stop = [w['id'] for w in gpu_workers if w['hostname']]

if len(ids_to_stop) > 0:
Expand Down

0 comments on commit 94ed16e

Please sign in to comment.