From ea9d4ab782914b4e189c2f36f8cc03dbeb9c8b3f Mon Sep 17 00:00:00 2001 From: Kian-Tat Lim Date: Tue, 8 Oct 2024 09:57:45 -0700 Subject: [PATCH] Explicitly control accepts. --- python/s3daemon/s3daemon.py | 52 ++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/python/s3daemon/s3daemon.py b/python/s3daemon/s3daemon.py index 97b9333..09d4b29 100644 --- a/python/s3daemon/s3daemon.py +++ b/python/s3daemon/s3daemon.py @@ -22,6 +22,7 @@ import asyncio import logging import os +import socket import time import aiobotocore.session @@ -30,6 +31,7 @@ max_connections = int(os.environ.get("S3DAEMON_MAX_CONNECTIONS", 25)) connect_timeout = float(os.environ.get("S3DAEMON_CONNECT_TIMEOUT", 5.0)) max_retries = int(os.environ.get("S3DAEMON_MAX_RETRIES", 2)) +max_clients = int(os.environ.get("S3DAEMON_MAX_CLIENTS", 25)) config = botocore.config.Config( max_pool_connections=max_connections, @@ -61,31 +63,29 @@ log.setLevel(logging.INFO) -async def handle_client(client, reader, writer): +async def handle_client(client, conn): """Handle a client connection to the server socket. Parameters ---------- client : `S3` The S3 client to use to talk to the server. - reader : `asyncio.StreamReader` - A stream connected to the socket to read the filename/destination pair. - writer : `asyncio.StreamWriter` - A stream connected to the socket to write back status information. + conn : `socket.Socket` + The socket connected to the client. """ - filename, dest = (await reader.readline()).decode("UTF-8").rstrip().split(" ") + filename, dest = conn.recv(4096).decode("UTF-8").rstrip().split(" ") start = time.time() # ignore the alias _, bucket, key = dest.split("/", maxsplit=2) try: with open(filename, "rb") as f: await client.put_object(Body=f, Bucket=bucket, Key=key) - writer.write(b"Success") - log.info("%f %f sec - %s", start, time.time() - start, filename) + conn.send(b"Success") + log.info("%f %f sec - %s", start, time.time() - start, key) except Exception as e: - writer.write(bytes(repr(e), "UTF-8")) - log.exception("%f %f sec - %s", start, time.time() - start, filename) - + conn.send(bytes(repr(e), "UTF-8")) + log.exception("%f %f sec - %s", start, time.time() - start, key) + conn.close() async def go(): """Start the server.""" @@ -97,14 +97,30 @@ async def go(): endpoint_url=endpoint_url, config=config, ) as client: - - async def client_cb(reader, writer): - await handle_client(client, reader, writer) - - server = await asyncio.start_server(client_cb, host, port) + sem = asyncio.Semaphore(max_clients) + background_tasks = set() log.info("Starting server") - async with server: - await server.serve_forever() + with socket.create_server((host, port)) as s: + # We don't want to block in accept(); we need to run other tasks. + s.setblocking(False) + while True: + # Do not allow more accepts if we're already handling the + # maximum number of clients. + await sem.acquire() + while True: + try: + conn, _ = s.accept() + break + except (TimeoutError, BlockingIOError): + # Allow other tasks to run. + await asyncio.sleep(0) + task = asyncio.create_task(handle_client(client, conn)) + # Add to set to avoid premature cleanup. + background_tasks.add(task) + # Release semaphore when task is handled. + task.add_done_callback(lambda _: sem.release()) + # Remove from set when finished. + task.add_done_callback(background_tasks.discard) def main():