From 8780f0de47a4b2d29c842016880fa4f9257f88a1 Mon Sep 17 00:00:00 2001 From: fotstrt Date: Sun, 10 Nov 2024 17:27:58 +0100 Subject: [PATCH] pass port --- sailor/controller.py | 19 +++++++++++-------- sailor/orchestration.proto | 5 +++-- sailor/orchestration_pb2.py | 12 ++++++------ sailor/run_ft_train.py | 2 +- sailor/run_train_custom.py | 3 ++- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/sailor/controller.py b/sailor/controller.py index 634f6ef6..f7a964d7 100644 --- a/sailor/controller.py +++ b/sailor/controller.py @@ -14,13 +14,14 @@ def get_slurm_nodelist(): return hostnames_list[:-1] class ClusterController: - def __init__(self, world_size: int, grpc_port: int) -> None: + def __init__(self, world_size: int, grpc_port: int, training_master_port: int) -> None: self.world_size = world_size self.hostnames = get_slurm_nodelist() addresses = ",".join(self.hostnames) os.environ["no_proxy"] = addresses self.num_nodes = len(self.hostnames) self.grpc_port = grpc_port + self.training_master_port = training_master_port print(f"Num nodes {self.num_nodes}, Hostnames: {self.hostnames}") self.alive_nodes = [] @@ -49,7 +50,7 @@ def monitor(self) -> None: break # broadcast new topology - self.send_new_topology(new_topology) + self.send_new_topology(self.training_master_port, new_topology) else: self.alive_nodes = new_alive_nodes time.sleep(10) @@ -62,13 +63,13 @@ def decide_topology(self) -> list[str]: return self.alive_nodes[:self.world_size] - def send_new_topology(self, topology: list[str]) -> None: + def send_new_topology(self, port: int, topology: list[str]) -> None: for node in self.alive_nodes: - self.send_new_topology_to_node(node, topology) + self.send_new_topology_to_node(node, port, topology) - def send_new_topology_to_node(self, node: str, topology: list[str]) -> None: - request = WorkerConfigurationRequest(topology=topology) + def send_new_topology_to_node(self, node: str, port: int, topology: list[str]) -> None: + request = WorkerConfigurationRequest(port=port, topology=topology) grpc_target = f'{node}:{self.grpc_port}' with grpc.insecure_channel(grpc_target) as channel: stub = WorkerAgentStub(channel) @@ -115,9 +116,11 @@ def kill_node(self, node: str, abort: bool) -> None: help='world_size (in number of nodes)', required=True) parser.add_argument('--grpc_port', type=int, help='Port to start grpc server', required=True) + parser.add_argument('--training_master_port', type=int, + help='Port used for training', required=True) args = parser.parse_args() time.sleep(10) # some sleep time to allow the workers to start their grpc servers - controller = ClusterController(args.world_size, args.grpc_port) - controller.monitor() \ No newline at end of file + controller = ClusterController(args.world_size, args.grpc_port, args.training_master_port) + controller.monitor() diff --git a/sailor/orchestration.proto b/sailor/orchestration.proto index 9d56b476..e14677f8 100644 --- a/sailor/orchestration.proto +++ b/sailor/orchestration.proto @@ -21,8 +21,9 @@ message KillResponse { } message WorkerConfigurationRequest { - repeated string topology = 1; + int32 port = 1; + repeated string topology = 2; } message WorkerConfigurationResponse { -} \ No newline at end of file +} diff --git a/sailor/orchestration_pb2.py b/sailor/orchestration_pb2.py index 960fc8c9..ae522ff0 100644 --- a/sailor/orchestration_pb2.py +++ b/sailor/orchestration_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13orchestration.proto\"\x13\n\x11\x43heckReadyRequest\"&\n\x12\x43heckReadyResponse\x12\x10\n\x08is_ready\x18\x01 \x01(\x08\"\x1c\n\x0bKillRequest\x12\r\n\x05\x61\x62ort\x18\x01 \x01(\x08\"\x0e\n\x0cKillResponse\".\n\x1aWorkerConfigurationRequest\x12\x10\n\x08topology\x18\x01 \x03(\t\"\x1d\n\x1bWorkerConfigurationResponse2\xc1\x01\n\x0bWorkerAgent\x12R\n\x13\x43onfigurationChange\x12\x1b.WorkerConfigurationRequest\x1a\x1c.WorkerConfigurationResponse\"\x00\x12\x37\n\nCheckReady\x12\x12.CheckReadyRequest\x1a\x13.CheckReadyResponse\"\x00\x12%\n\x04Kill\x12\x0c.KillRequest\x1a\r.KillResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13orchestration.proto\"\x13\n\x11\x43heckReadyRequest\"&\n\x12\x43heckReadyResponse\x12\x10\n\x08is_ready\x18\x01 \x01(\x08\"\x1c\n\x0bKillRequest\x12\r\n\x05\x61\x62ort\x18\x01 \x01(\x08\"\x0e\n\x0cKillResponse\"<\n\x1aWorkerConfigurationRequest\x12\x0c\n\x04port\x18\x01 \x01(\x05\x12\x10\n\x08topology\x18\x02 \x03(\t\"\x1d\n\x1bWorkerConfigurationResponse2\xc1\x01\n\x0bWorkerAgent\x12R\n\x13\x43onfigurationChange\x12\x1b.WorkerConfigurationRequest\x1a\x1c.WorkerConfigurationResponse\"\x00\x12\x37\n\nCheckReady\x12\x12.CheckReadyRequest\x1a\x13.CheckReadyResponse\"\x00\x12%\n\x04Kill\x12\x0c.KillRequest\x1a\r.KillResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -40,9 +40,9 @@ _globals['_KILLRESPONSE']._serialized_start=114 _globals['_KILLRESPONSE']._serialized_end=128 _globals['_WORKERCONFIGURATIONREQUEST']._serialized_start=130 - _globals['_WORKERCONFIGURATIONREQUEST']._serialized_end=176 - _globals['_WORKERCONFIGURATIONRESPONSE']._serialized_start=178 - _globals['_WORKERCONFIGURATIONRESPONSE']._serialized_end=207 - _globals['_WORKERAGENT']._serialized_start=210 - _globals['_WORKERAGENT']._serialized_end=403 + _globals['_WORKERCONFIGURATIONREQUEST']._serialized_end=190 + _globals['_WORKERCONFIGURATIONRESPONSE']._serialized_start=192 + _globals['_WORKERCONFIGURATIONRESPONSE']._serialized_end=221 + _globals['_WORKERAGENT']._serialized_start=224 + _globals['_WORKERAGENT']._serialized_end=417 # @@protoc_insertion_point(module_scope) diff --git a/sailor/run_ft_train.py b/sailor/run_ft_train.py index b42c1d80..22896441 100644 --- a/sailor/run_ft_train.py +++ b/sailor/run_ft_train.py @@ -45,7 +45,7 @@ def ConfigurationChange(self, request, context): topology_list = list(request.topology) if self.is_in_topo(topology_list): print(f"Starting new process, node rank is {self.node_rank}") - start_cmd_base = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --master-ip {self.master_addr}" + start_cmd_base = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --master-ip {self.master_addr} --master-port {request.port}" for i in range(self.gpus_per_node): print(f"Start for process {i}") rank_i = self.node_rank*self.gpus_per_node + i diff --git a/sailor/run_train_custom.py b/sailor/run_train_custom.py index 8667ce1c..f9eaf056 100644 --- a/sailor/run_train_custom.py +++ b/sailor/run_train_custom.py @@ -226,6 +226,7 @@ def get_args(): parser.add_argument("--world-size", type=int, required=True, help="World size") parser.add_argument("--rank", type=int, required=True, help="Rank") parser.add_argument("--master-ip", type=str, required=True, help="Master address") + parser.add_argument("--master-port", type=int, required=True, help="Master port") return parser.parse_args() @@ -238,7 +239,7 @@ def get_args(): os.environ['RANK'] = str(args.rank) os.environ['LOCAL_RANK'] = str(args.rank % 4) os.environ['MASTER_ADDR'] = args.master_ip - os.environ['MASTER_PORT'] = "1234" # TODO + os.environ['MASTER_PORT'] = str(args.master_port) # Load trainer and data