Skip to content

Commit

Permalink
Reduce CLIENT_TIMEOUT_SEC in benchmarking script (#932)
Browse files Browse the repository at this point in the history
first commit
  • Loading branch information
Bslabe123 authored Jan 15, 2025
1 parent c985e95 commit 35d41d9
Showing 1 changed file with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from google.protobuf.timestamp_pb2 import Timestamp

MIN_SEQ_LEN = 4
CLIENT_TIMEOUT_SEC = 3 * 60 * 60
NEW_TEXT_KEY = "\nOutput:\n"
PROMETHEUS_PORT = 9090

Expand Down Expand Up @@ -148,6 +147,7 @@ async def send_stream_request(
tokenizer: PreTrainedTokenizerBase,
sax_model: str,
model: str,
timeout: float,
) -> Tuple[Tuple[int, int, float], float, Dict[str, int]]:
"""Sends stream request to server"""
request_start_time = time.time()
Expand Down Expand Up @@ -179,7 +179,7 @@ async def send_stream_request(
ttft = 0.0
st = time.perf_counter()
output = ""
timeout = aiohttp.ClientTimeout(total=CLIENT_TIMEOUT_SEC)
timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=timeout,trust_env=True) as session:
try:
async with session.post(api_url, headers=headers, json=pload, ssl=False) as response:
Expand Down Expand Up @@ -249,6 +249,7 @@ async def send_request(
tokenizer: PreTrainedTokenizerBase,
sax_model: str,
model: str,
timeout: float,
) -> Tuple[Tuple[int, int, float], float, Dict[str, int]]:
"""Sends request to server."""
request_start_time = time.time()
Expand Down Expand Up @@ -322,7 +323,7 @@ async def send_request(
raise ValueError(f"Unknown backend: {backend}")

# Set client timeout to be 3 hrs.
timeout = aiohttp.ClientTimeout(total=CLIENT_TIMEOUT_SEC)
timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=timeout,trust_env=True,trace_configs=[trace_config]) as session:
while True:
try:
Expand Down Expand Up @@ -426,6 +427,7 @@ async def benchmark(
tokenizer,
args.sax_model,
model,
args.request_timeout,
)
)
else:
Expand All @@ -442,6 +444,7 @@ async def benchmark(
tokenizer,
args.sax_model,
model,
args.request_timeout,
)
)
tasks.append(task)
Expand Down Expand Up @@ -834,6 +837,12 @@ async def main(args: argparse.Namespace):
action="store_true",
help="Whether to stream the request. Needed for TTFT metric",
)
parser.add_argument(
"--request-timeout",
type=float,
default=(3.0 * 60.0 * 60.0),
help="Individual request timeout",
)
parser.add_argument(
"--tokenizer",
type=str,
Expand Down

0 comments on commit 35d41d9

Please sign in to comment.