Skip to content

Commit

Permalink
pass port
Browse files — browse the repository at this point in the history
  • Loading branch information
fotstrt committed Nov 10, 2024
1 parent 66d797c commit 8780f0d
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 18 deletions.
19 changes: 11 additions & 8 deletions sailor/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ def get_slurm_nodelist():
return hostnames_list[:-1]

class ClusterController:
def __init__(self, world_size: int, grpc_port: int) -> None:
def __init__(self, world_size: int, grpc_port: int, training_master_port: int) -> None:
self.world_size = world_size
self.hostnames = get_slurm_nodelist()
addresses = ",".join(self.hostnames)
os.environ["no_proxy"] = addresses
self.num_nodes = len(self.hostnames)
self.grpc_port = grpc_port
self.training_master_port = training_master_port

print(f"Num nodes {self.num_nodes}, Hostnames: {self.hostnames}")
self.alive_nodes = []
Expand Down Expand Up @@ -49,7 +50,7 @@ def monitor(self) -> None:
break

# broadcast new topology
self.send_new_topology(new_topology)
self.send_new_topology(self.training_master_port, new_topology)
else:
self.alive_nodes = new_alive_nodes
time.sleep(10)
Expand All @@ -62,13 +63,13 @@ def decide_topology(self) -> list[str]:
return self.alive_nodes[:self.world_size]


def send_new_topology(self, topology: list[str]) -> None:
def send_new_topology(self, port: int, topology: list[str]) -> None:
for node in self.alive_nodes:
self.send_new_topology_to_node(node, topology)
self.send_new_topology_to_node(node, port, topology)


def send_new_topology_to_node(self, node: str, topology: list[str]) -> None:
request = WorkerConfigurationRequest(topology=topology)
def send_new_topology_to_node(self, node: str, port: int, topology: list[str]) -> None:
request = WorkerConfigurationRequest(port=port, topology=topology)
grpc_target = f'{node}:{self.grpc_port}'
with grpc.insecure_channel(grpc_target) as channel:
stub = WorkerAgentStub(channel)
Expand Down Expand Up @@ -115,9 +116,11 @@ def kill_node(self, node: str, abort: bool) -> None:
help='world_size (in number of nodes)', required=True)
parser.add_argument('--grpc_port', type=int,
help='Port to start grpc server', required=True)
parser.add_argument('--training_master_port', type=int,
help='Port used for training', required=True)

args = parser.parse_args()

time.sleep(10) # some sleep time to allow the workers to start their grpc servers
controller = ClusterController(args.world_size, args.grpc_port)
controller.monitor()
controller = ClusterController(args.world_size, args.grpc_port, args.training_master_port)
controller.monitor()
5 changes: 3 additions & 2 deletions sailor/orchestration.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ message KillResponse {
}

message WorkerConfigurationRequest {
repeated string topology = 1;
int32 port = 1;
repeated string topology = 2;
}

message WorkerConfigurationResponse {
}
}
12 changes: 6 additions & 6 deletions sailor/orchestration_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sailor/run_ft_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def ConfigurationChange(self, request, context):
topology_list = list(request.topology)
if self.is_in_topo(topology_list):
print(f"Starting new process, node rank is {self.node_rank}")
start_cmd_base = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --master-ip {self.master_addr}"
start_cmd_base = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --master-ip {self.master_addr} --master-port {request.port}"
for i in range(self.gpus_per_node):
print(f"Start for process {i}")
rank_i = self.node_rank*self.gpus_per_node + i
Expand Down
3 changes: 2 additions & 1 deletion sailor/run_train_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def get_args():
parser.add_argument("--world-size", type=int, required=True, help="World size")
parser.add_argument("--rank", type=int, required=True, help="Rank")
parser.add_argument("--master-ip", type=str, required=True, help="Master address")
parser.add_argument("--master-port", type=int, required=True, help="Master port")

return parser.parse_args()

Expand All @@ -238,7 +239,7 @@ def get_args():
os.environ['RANK'] = str(args.rank)
os.environ['LOCAL_RANK'] = str(args.rank % 4)
os.environ['MASTER_ADDR'] = args.master_ip
os.environ['MASTER_PORT'] = "1234" # TODO
os.environ['MASTER_PORT'] = str(args.master_port)


# Load trainer and data
Expand Down

0 comments on commit 8780f0d

Please sign in to comment.