diff --git a/src/async_s3/list_objects_async.py b/src/async_s3/list_objects_async.py
index a29b661..e77cade 100644
--- a/src/async_s3/list_objects_async.py
+++ b/src/async_s3/list_objects_async.py
@@ -1,6 +1,6 @@
 import asyncio
 import functools
-from typing import Iterable, Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, List, AsyncIterator, Set
 
 import aiobotocore.session
 import aiobotocore.client
@@ -9,6 +9,9 @@
 from async_s3.group_by_prefix import group_by_prefix
 
 
+MAX_CONCURRENT_TASKS = 100
+
+
 @functools.lru_cache()
 def create_session() -> aiobotocore.session.AioSession:
     """Create a session object."""
@@ -27,17 +30,21 @@ def get_s3_client() -> aiobotocore.client.AioBaseClient:
 
 class ListObjectsAsync:
     def __init__(self, bucket: str) -> None:
         self._bucket = bucket
+        self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
 
-    async def _list_objects(  # pylint: disable=too-many-arguments
+    async def _list_objects(  # pylint: disable=too-many-arguments, too-many-locals
         self,
         s3_client: aiobotocore.client.AioBaseClient,
         prefix: str,
         current_depth: int,
         max_depth: Optional[int],
-        max_folders: Optional[int] = None,
-    ) -> Tuple[Iterable[Dict[str, Any]], Iterable[Tuple[str, int]]]:
+        max_folders: Optional[int],
+        objects_keys: Set[str],
+        queue: asyncio.Queue[List[Dict[str, Any]]],
+        active_tasks: Set[asyncio.Task[None]],
+    ) -> None:
+        """Emit object pages to the queue."""
         paginator = s3_client.get_paginator("list_objects_v2")
-        objects = []
         prefixes = []
         params = {"Bucket": self._bucket, "Prefix": prefix}
@@ -45,57 +52,101 @@ async def _list_objects(  # pylint: disable=too-many-arguments
             params["Delimiter"] = "/"
 
         async for page in paginator.paginate(**params):
-            for obj in page.get("Contents", []):
-                key: str = obj["Key"]
-                if key.endswith("/"):
-                    continue  # Omit "directories"
-
-                objects.append(obj)
+            objects = page.get("Contents", [])
+            new_keys = {
+                obj["Key"]
+                for obj in objects
+                if not obj["Key"].endswith("/") and obj["Key"] not in objects_keys
+            }
+            cleared_objects = [obj for obj in objects if obj["Key"] in new_keys]
+            objects_keys.update(new_keys)
+            await queue.put(cleared_objects)
 
             if "Delimiter" in params:
                 prefixes.extend([prefix["Prefix"] for prefix in page.get("CommonPrefixes", [])])
 
+        level = -1 if current_depth == -1 else current_depth + 1
         if max_folders is not None and (len(prefixes) > max_folders):
-            prefixes = [(key, -1) for key in group_by_prefix(prefixes, max_folders)]
-        else:
-            prefixes = [
-                (key, -1 if current_depth == -1 else current_depth + 1) for key in prefixes
-            ]
-        return objects, prefixes
-
-    async def list_objects(
+            prefixes = list(group_by_prefix(prefixes, max_folders))
+            level = -1
+
+        for folder in prefixes:
+            await self.semaphore.acquire()
+            try:
+                task = asyncio.create_task(
+                    self._list_objects(
+                        s3_client,
+                        folder,
+                        level,
+                        max_depth,
+                        max_folders,
+                        objects_keys,
+                        queue,
+                        active_tasks,
+                    )
+                )
+                active_tasks.add(task)
+                task.add_done_callback(lambda t: self._task_done(t, active_tasks))
+            except Exception as e:
+                self.semaphore.release()
+                raise e
+
+    async def pages(
         self, prefix: str = "/", max_depth: Optional[int] = None, max_folders: Optional[int] = None
-    ) -> Iterable[Dict[str, Any]]:
-        """List all objects in the bucket with given prefix.
+    ) -> AsyncIterator[List[Dict[str, Any]]]:
+        """Generator that yields objects in the bucket with the given prefix.
+
+        Yields objects page by page (each page is a list of AWS S3 object dicts).
 
         max_depth: The maximum folders depth to traverse in separate requests. If None, traverse all levels.
         max_folders: The maximum number of folders to load in separate requests. If None, requests all folders.
         Otherwise, the folders are grouped by prefixes before loading in separate requests.
         Try to group in the given number of folders if possible.
         """
-        # if we have items with the same prefixes as folders we could have duplicates
-        # so we use dict to clear them out
-        objects: Dict[str, Dict[str, Any]] = {}
-        tasks = set()
+        # if we group by prefixes, some objects may be listed multiple times
+        # to avoid this, we store the keys of the objects already listed
+        objects_keys: Set[str] = set()
+
+        # queue to store the objects pages from the tasks
+        queue: asyncio.Queue[List[Dict[str, Any]]] = asyncio.Queue()
+
+        # set to keep track of active tasks
+        active_tasks: Set[asyncio.Task[None]] = set()
 
         async with get_s3_client() as s3_client:
-            tasks.add(asyncio.create_task(self._list_objects(s3_client, prefix, 0, max_depth)))
-
-            while tasks:
-                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-                tasks = pending
-
-                for task in done:
-                    files, folders = await task
-                    objects.update({file["Key"]: file for file in files})
-
-                    for folder, level in folders:
-                        tasks.add(
-                            asyncio.create_task(
-                                self._list_objects(
-                                    s3_client, folder, level, max_depth, max_folders
-                                )
-                            )
-                        )
-
-        return list(objects.values())
+            root_task = asyncio.create_task(
+                self._list_objects(
+                    s3_client, prefix, 0, max_depth, max_folders, objects_keys, queue, active_tasks
+                )
+            )
+            active_tasks.add(root_task)
+            root_task.add_done_callback(lambda t: self._task_done(t, active_tasks))
+
+            while active_tasks or not queue.empty():
+                try:
+                    yield queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    await asyncio.sleep(0)  # let the listing tasks run
+
+            if active_tasks:
+                await asyncio.gather(*active_tasks)
+
+    def _task_done(self, task: asyncio.Task[None], active_tasks: Set[asyncio.Task[None]]) -> None:
+        """Callback for when a task is done."""
+        active_tasks.discard(task)
+        self.semaphore.release()
+
+    async def list_objects(
+        self, prefix: str = "/", max_depth: Optional[int] = None, max_folders: Optional[int] = None
+    ) -> List[Dict[str, Any]]:
+        """List all objects in the bucket with the given prefix.
+
+        max_depth: The maximum folders depth to traverse in separate requests. If None, traverse all levels.
+        max_folders: The maximum number of folders to load in separate requests. If None, requests all folders.
+        Otherwise, the folders are grouped by prefixes before loading in separate requests.
+        Try to group to the given `max_folders` if possible.
+ """ + objects = [] + async for objects_page in self.pages(prefix, max_depth, max_folders): + objects.extend(objects_page) + return objects diff --git a/tests/test_list_objects_async.py b/tests/test_list_objects_async.py index 645ed3f..d2cdd5a 100644 --- a/tests/test_list_objects_async.py +++ b/tests/test_list_objects_async.py @@ -44,7 +44,7 @@ async def test_list_objects_functional(mock_s3_structure): async def test_list_objects_with_max_depth(s3_client_proxy, mock_s3_structure): walker = ListObjectsAsync("mock-bucket") - objects = await walker.list_objects(prefix="root/", max_depth=1) + objects = await walker.list_objects(prefix="root/", max_depth=2) expected_keys = { 'root/data01/image01.png', 'root/data01/images/img11.jpg', @@ -70,13 +70,27 @@ async def test_list_objects_with_max_depth(s3_client_proxy, mock_s3_structure): call.get_paginator("list_objects_v2"), call.get_paginator().paginate(Bucket="mock-bucket", Prefix="root/", Delimiter="/"), call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/', Delimiter='/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/', Delimiter='/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03/', Delimiter='/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04/', Delimiter='/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/archives/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/docs/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/images/'), + call.get_paginator('list_objects_v2'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/temp/'), call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/logs/'), call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03/'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/reports/'), call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04/'), + call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/scripts/'), ] assert s3_client_proxy.calls == expected_calls @@ -118,26 +132,12 @@ async def test_list_objects_with_max_folders(s3_client_proxy, mock_s3_structure) call.get_paginator("list_objects_v2"), call.get_paginator().paginate(Bucket="mock-bucket", Prefix="root/", Delimiter="/"), call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/', Delimiter='/'), - call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/', Delimiter='/'), - call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03/', Delimiter='/'), - call.get_paginator('list_objects_v2'), - call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04/', Delimiter='/'), - call.get_paginator('list_objects_v2'), - 
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/a'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/d'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/i'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/t'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/l'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/r'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/s'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04'),
     ]
     assert s3_client_proxy.calls == expected_calls
\ No newline at end of file
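
A usage sketch for the streaming API added in this change. The bucket name, prefix, and parameter values below are hypothetical, and AWS credentials are assumed to be configured in the environment:

import asyncio

from async_s3.list_objects_async import ListObjectsAsync


async def main() -> None:
    walker = ListObjectsAsync("my-bucket")  # hypothetical bucket name

    # Stream pages as they arrive instead of holding the whole listing in memory.
    async for page in walker.pages(prefix="data/", max_depth=2, max_folders=10):
        for obj in page:
            print(obj["Key"], obj["Size"])

    # Or collect the full listing at once via the wrapper built on pages().
    objects = await walker.list_objects(prefix="data/")
    print(f"{len(objects)} objects total")


asyncio.run(main())

The queue decouples producing pages from consuming them, while the module-level semaphore caps the number of concurrent listing tasks at MAX_CONCURRENT_TASKS; the _task_done callback releases a semaphore slot as each task completes.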