Skip to content

Commit

Permalink
worker fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fotstrt committed Nov 8, 2024
1 parent ea1c23e commit 7c69b82
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 3 deletions.
48 changes: 48 additions & 0 deletions sailor/orchestration_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

183 changes: 183 additions & 0 deletions sailor/orchestration_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

import orchestration_pb2 as orchestration__pb2

# grpcio version this module was generated against, and the grpcio version
# that is installed at import time.
GRPC_GENERATED_VERSION = '1.66.0'
GRPC_VERSION = grpc.__version__

# The installed runtime is treated as unsupported when the comparison helper
# reports it as lower than the generated version, or when the helper itself
# cannot be imported from the installed grpc package.
try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(
        GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    # Fail at import time so the mismatch is reported before any RPC is made.
    _version_mismatch_message = (
        f'The grpc package installed is at version {GRPC_VERSION},'
        f' but the generated code in orchestration_pb2_grpc.py depends on'
        f' grpcio>={GRPC_GENERATED_VERSION}.'
        f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
    raise RuntimeError(_version_mismatch_message)


class WorkerAgentStub(object):
    """Client stub exposing the WorkerAgent RPCs as unary-unary callables.

    After construction the instance has three attributes, ``ConfigurationChange``,
    ``CheckReady`` and ``Kill``, each created by ``channel.unary_unary`` for the
    corresponding ``/WorkerAgent/...`` method path.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        messages = orchestration__pb2

        def bind(method_path, request_type, response_type):
            # Every WorkerAgent RPC is unary-unary and differs only in its
            # method path and its request/response message types.
            return channel.unary_unary(
                method_path,
                request_serializer=request_type.SerializeToString,
                response_deserializer=response_type.FromString,
                _registered_method=True)

        self.ConfigurationChange = bind(
            '/WorkerAgent/ConfigurationChange',
            messages.WorkerConfigurationRequest,
            messages.WorkerConfigurationResponse)
        self.CheckReady = bind(
            '/WorkerAgent/CheckReady',
            messages.CheckReadyRequest,
            messages.CheckReadyResponse)
        self.Kill = bind(
            '/WorkerAgent/Kill',
            messages.KillRequest,
            messages.KillResponse)


def _reject_unimplemented(context):
    """Mark the call on ``context`` as UNIMPLEMENTED and raise.

    Raises:
        NotImplementedError: always.
    """
    details = 'Method not implemented!'
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details(details)
    raise NotImplementedError(details)


class WorkerAgentServicer(object):
    """Base servicer for the WorkerAgent service.

    Every method here is a placeholder that reports UNIMPLEMENTED on the call
    context and raises ``NotImplementedError``; subclasses override the RPCs
    they support.
    """

    def ConfigurationChange(self, request, context):
        """Placeholder for the ConfigurationChange RPC; always raises."""
        _reject_unimplemented(context)

    def CheckReady(self, request, context):
        """Placeholder for the CheckReady RPC; always raises."""
        _reject_unimplemented(context)

    def Kill(self, request, context):
        """Placeholder for the Kill RPC; always raises."""
        _reject_unimplemented(context)


def add_WorkerAgentServicer_to_server(servicer, server):
    """Attach a WorkerAgent servicer's RPC handlers to a gRPC server.

    Builds one unary-unary method handler per RPC (ConfigurationChange,
    CheckReady, Kill) and registers the set twice on ``server``: wrapped in a
    generic handler, and as registered method handlers.

    Args:
        servicer: Object providing ``ConfigurationChange``, ``CheckReady`` and
            ``Kill`` methods.
        server: Server exposing ``add_generic_rpc_handlers`` and
            ``add_registered_method_handlers``.
    """
    service_name = 'WorkerAgent'
    messages = orchestration__pb2
    make_handler = grpc.unary_unary_rpc_method_handler

    handlers_by_method = {}
    handlers_by_method['ConfigurationChange'] = make_handler(
        servicer.ConfigurationChange,
        request_deserializer=messages.WorkerConfigurationRequest.FromString,
        response_serializer=messages.WorkerConfigurationResponse.SerializeToString,
    )
    handlers_by_method['CheckReady'] = make_handler(
        servicer.CheckReady,
        request_deserializer=messages.CheckReadyRequest.FromString,
        response_serializer=messages.CheckReadyResponse.SerializeToString,
    )
    handlers_by_method['Kill'] = make_handler(
        servicer.Kill,
        request_deserializer=messages.KillRequest.FromString,
        response_serializer=messages.KillResponse.SerializeToString,
    )

    service_handler = grpc.method_handlers_generic_handler(
        service_name, handlers_by_method)
    server.add_generic_rpc_handlers((service_handler,))
    server.add_registered_method_handlers(service_name, handlers_by_method)


# This class is part of an EXPERIMENTAL API.
class WorkerAgent(object):
    """Static, stub-less entry points for the WorkerAgent RPCs.

    Each static method forwards its arguments to
    ``grpc.experimental.unary_unary`` together with the RPC's method path and
    its request serializer / response deserializer, and returns that call's
    result.
    """

    @staticmethod
    def ConfigurationChange(request,
                            target,
                            options=(),
                            channel_credentials=None,
                            call_credentials=None,
                            insecure=False,
                            compression=None,
                            wait_for_ready=None,
                            timeout=None,
                            metadata=None):
        """Invoke ``/WorkerAgent/ConfigurationChange`` against ``target``."""
        messages = orchestration__pb2
        return grpc.experimental.unary_unary(
            request,
            target,
            '/WorkerAgent/ConfigurationChange',
            request_serializer=messages.WorkerConfigurationRequest.SerializeToString,
            response_deserializer=messages.WorkerConfigurationResponse.FromString,
            options=options,
            channel_credentials=channel_credentials,
            insecure=insecure,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata,
            _registered_method=True)

    @staticmethod
    def CheckReady(request,
                   target,
                   options=(),
                   channel_credentials=None,
                   call_credentials=None,
                   insecure=False,
                   compression=None,
                   wait_for_ready=None,
                   timeout=None,
                   metadata=None):
        """Invoke ``/WorkerAgent/CheckReady`` against ``target``."""
        messages = orchestration__pb2
        return grpc.experimental.unary_unary(
            request,
            target,
            '/WorkerAgent/CheckReady',
            request_serializer=messages.CheckReadyRequest.SerializeToString,
            response_deserializer=messages.CheckReadyResponse.FromString,
            options=options,
            channel_credentials=channel_credentials,
            insecure=insecure,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata,
            _registered_method=True)

    @staticmethod
    def Kill(request,
             target,
             options=(),
             channel_credentials=None,
             call_credentials=None,
             insecure=False,
             compression=None,
             wait_for_ready=None,
             timeout=None,
             metadata=None):
        """Invoke ``/WorkerAgent/Kill`` against ``target``."""
        messages = orchestration__pb2
        return grpc.experimental.unary_unary(
            request,
            target,
            '/WorkerAgent/Kill',
            request_serializer=messages.KillRequest.SerializeToString,
            response_deserializer=messages.KillResponse.FromString,
            options=options,
            channel_credentials=channel_credentials,
            insecure=insecure,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata,
            _registered_method=True)
17 changes: 14 additions & 3 deletions sailor/run_ft_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import grpc
import os
import time
import socket
import argparse
Expand Down Expand Up @@ -47,7 +48,9 @@ class ElasticWorkerAgent(WorkerAgentServicer):
def __init__(self, script_args):
self.training_process = None
self.hostname = socket.gethostname()
self.world_size = 0
self.node_rank = -1
self.master_addr = None
self.script_args = script_args
print(f"Hello from grpc server {self.hostname}")

Expand All @@ -70,7 +73,7 @@ def ConfigurationChange(self, request, context):
topology_list = list(request.topology)
if self.is_in_topo(topology_list):
print(f"Starting new process, node rank is {self.node_rank}")
self.training_process = multiprocessing.Process(target=run)
self.training_process = multiprocessing.Process(target=run, args=(args.config_file, self.world_size, self.node_rank, self.master_addr))
self.training_process.start()

return WorkerConfigurationResponse()
Expand All @@ -79,6 +82,8 @@ def is_in_topo(self, topology):
if self.hostname not in topology:
return False
self.node_rank = topology.index(self.hostname)
self.world_size = len(topology)
self.master_addr = topology[0]
return True


Expand Down Expand Up @@ -265,7 +270,13 @@ def get_args():
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
return parser.parse_args()

def run(config_file):
def run(config_file, world_size, rank, master_addr):

os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = "1234" # TODO

trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)

Expand Down Expand Up @@ -297,4 +308,4 @@ def terminate(signum, _):
print("Start server!")
server.start()
signal.signal(signal.SIGTERM, terminate)
server.wait_for_termination()
server.wait_for_termination()

0 comments on commit 7c69b82

Please sign in to comment.