Skip to content

Commit

Permalink
add searching for free ports in unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Jan 29, 2024
1 parent 5f3ce67 commit 9a03a04
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tests/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,25 @@
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple
import random
import socket

import torch.cuda
from nanotron.parallel import ParallelContext
from torch.distributed.launcher import elastic_launch


def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
    """Return a TCP port on localhost that is currently free to bind.

    Picks ports uniformly at random in ``[min_port, max_port]`` and keeps
    trying until a bind succeeds, so concurrent test processes are unlikely
    to collide on the same rendezvous port.

    Args:
        min_port: Lower bound (inclusive) of the port range to sample from.
        max_port: Upper bound (inclusive) of the port range to sample from.

    Returns:
        A port number that was bindable at the time of the check. Note the
        usual TOCTOU caveat: another process could grab it before the caller
        binds it; SO_REUSEADDR mitigates lingering TIME_WAIT sockets.
    """
    while True:
        port = random.randint(min_port, max_port)
        try:
            with socket.socket() as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(("localhost", port))
                return port
        except OSError:
            # Port is already in use (or otherwise unbindable) -- retry with
            # another random port. The original `raise e` here made the retry
            # loop dead code: the very first busy port aborted the search.
            continue

def available_gpus():
if not torch.cuda.is_available():
return 0
Expand Down Expand Up @@ -92,6 +105,8 @@ def _init_distributed(func):
"""
nb_gpus = tp * dp * pp
run_id = uuid.uuid4()

port = find_free_port()

config = torch.distributed.launcher.LaunchConfig(
min_nodes=1,
Expand All @@ -101,7 +116,7 @@ def _init_distributed(func):
rdzv_configs={"timeout": 60},
# Use a pre-selected free port from `find_free_port` instead of a fixed one,
# so concurrent single-node test runs don't collide on the rendezvous endpoint.
rdzv_endpoint="localhost:0",
rdzv_endpoint=f"localhost:{port}",
run_id=str(run_id),
max_restarts=0,
# TODO @thomasw21: Tune as we increase the number of tests
Expand Down

0 comments on commit 9a03a04

Please sign in to comment.