Commit 576b6ae

Permalink: https://github.com/andgineer/async-s3/issues/11

list objects generator with internal duplicate removal (#8)

andgineer committed Jun 14, 2024
1 parent 67852b7 commit 576b6ae
Showing 2 changed files with 119 additions and 68 deletions.
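In short: when sibling folders are merged into shared prefixes (the max_folders option), the same object can be returned by more than one listing request. The new generator therefore tracks already-seen keys internally and yields each object exactly once, page by page, instead of deduplicating through a dict after the whole traversal finishes.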
141 changes: 96 additions & 45 deletions src/async_s3/list_objects_async.py
@@ -1,6 +1,6 @@
 import asyncio
 import functools
-from typing import Iterable, Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, List, AsyncIterator, Set

 import aiobotocore.session
 import aiobotocore.client
@@ -9,6 +9,9 @@
 from async_s3.group_by_prefix import group_by_prefix


+MAX_CONCURRENT_TASKS = 100
+
+
 @functools.lru_cache()
 def create_session() -> aiobotocore.session.AioSession:
     """Create a session object."""
@@ -27,75 +30,123 @@ def get_s3_client() -> aiobotocore.client.AioBaseClient:
 class ListObjectsAsync:
     def __init__(self, bucket: str) -> None:
         self._bucket = bucket
+        self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

-    async def _list_objects(  # pylint: disable=too-many-arguments
+    async def _list_objects(  # pylint: disable=too-many-arguments, too-many-locals
         self,
         s3_client: aiobotocore.client.AioBaseClient,
         prefix: str,
         current_depth: int,
         max_depth: Optional[int],
-        max_folders: Optional[int] = None,
-    ) -> Tuple[Iterable[Dict[str, Any]], Iterable[Tuple[str, int]]]:
+        max_folders: Optional[int],
+        objects_keys: Set[str],
+        queue: asyncio.Queue[List[Dict[str, Any]]],
+        active_tasks: Set[asyncio.Task[None]],
+    ) -> None:
+        """Emit object pages to the queue."""
         paginator = s3_client.get_paginator("list_objects_v2")
-        objects = []
         prefixes = []

         params = {"Bucket": self._bucket, "Prefix": prefix}
         if (current_depth != -1) and (max_depth is None or current_depth < max_depth):
             params["Delimiter"] = "/"

         async for page in paginator.paginate(**params):
-            for obj in page.get("Contents", []):
-                key: str = obj["Key"]
-                if key.endswith("/"):
-                    continue  # Omit "directories"
-
-                objects.append(obj)
+            objects = page.get("Contents", [])
+            new_keys = {
+                obj["Key"]
+                for obj in objects
+                if not obj["Key"].endswith("/") and obj["Key"] not in objects_keys
+            }
+            cleared_objects = [obj for obj in objects if obj["Key"] in new_keys]
+            objects_keys.update(new_keys)
+            await queue.put(cleared_objects)

             if "Delimiter" in params:
                 prefixes.extend([prefix["Prefix"] for prefix in page.get("CommonPrefixes", [])])

+        level = -1 if current_depth == -1 else current_depth + 1
         if max_folders is not None and (len(prefixes) > max_folders):
-            prefixes = [(key, -1) for key in group_by_prefix(prefixes, max_folders)]
-        else:
-            prefixes = [
-                (key, -1 if current_depth == -1 else current_depth + 1) for key in prefixes
-            ]
-        return objects, prefixes
+            prefixes = list(group_by_prefix(prefixes, max_folders))
+            level = -1

-    async def list_objects(
+        for folder in prefixes:
+            await self.semaphore.acquire()
+            try:
+                task = asyncio.create_task(
+                    self._list_objects(
+                        s3_client,
+                        folder,
+                        level,
+                        max_depth,
+                        max_folders,
+                        objects_keys,
+                        queue,
+                        active_tasks,
+                    )
+                )
+                active_tasks.add(task)
+                task.add_done_callback(lambda t: self._task_done(t, active_tasks))
+            except Exception as e:
+                self.semaphore.release()
+                raise e
+
+    async def pages(
         self, prefix: str = "/", max_depth: Optional[int] = None, max_folders: Optional[int] = None
-    ) -> Iterable[Dict[str, Any]]:
-        """List all objects in the bucket with given prefix.
+    ) -> AsyncIterator[List[Dict[str, Any]]]:
+        """Generator that yields objects in the bucket with the given prefix.
+        Yields objects page by page (each page is a list of AWS S3 object dicts).
         max_depth: The maximum folder depth to traverse in separate requests. If None, traverse all levels.
         max_folders: The maximum number of folders to load in separate requests. If None, request all folders.
         Otherwise, the folders are grouped by prefixes before loading in separate requests.
         Tries to group into the given number of folders if possible.
         """
-        # if we have items with the same prefixes as folders we could have duplicates
-        # so we use dict to clear them out
-        objects: Dict[str, Dict[str, Any]] = {}
-        tasks = set()
+        # if we group by prefixes, some objects may be listed multiple times
+        # to avoid this, we store the keys of the objects already listed
+        objects_keys: Set[str] = set()
+
+        # queue to store the object pages produced by the tasks
+        queue: asyncio.Queue[List[Dict[str, Any]]] = asyncio.Queue()
+
+        # set to keep track of active tasks
+        active_tasks: Set[asyncio.Task[None]] = set()

         async with get_s3_client() as s3_client:
-            tasks.add(asyncio.create_task(self._list_objects(s3_client, prefix, 0, max_depth)))
-
-            while tasks:
-                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-                tasks = pending
-
-                for task in done:
-                    files, folders = await task
-                    objects.update({file["Key"]: file for file in files})
-
-                    for folder, level in folders:
-                        tasks.add(
-                            asyncio.create_task(
-                                self._list_objects(
-                                    s3_client, folder, level, max_depth, max_folders
-                                )
-                            )
-                        )
-
-        return list(objects.values())
+            root_task = asyncio.create_task(
+                self._list_objects(
+                    s3_client, prefix, 0, max_depth, max_folders, objects_keys, queue, active_tasks
+                )
+            )
+            active_tasks.add(root_task)
+            root_task.add_done_callback(lambda t: self._task_done(t, active_tasks))
+
+            while active_tasks:
+                try:
+                    yield queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    await asyncio.sleep(0)
+
+        if active_tasks:
+            await asyncio.gather(*active_tasks)
+
+    def _task_done(self, task: asyncio.Task[None], active_tasks: Set[asyncio.Task[None]]) -> None:
+        """Callback for when a task is done."""
+        active_tasks.discard(task)
+        self.semaphore.release()
+
+    async def list_objects(
+        self, prefix: str = "/", max_depth: Optional[int] = None, max_folders: Optional[int] = None
+    ) -> List[Dict[str, Any]]:
+        """List all objects in the bucket with the given prefix.
+        max_depth: The maximum folder depth to traverse in separate requests. If None, traverse all levels.
+        max_folders: The maximum number of folders to load in separate requests. If None, request all folders.
+        Otherwise, the folders are grouped by prefixes before loading in separate requests.
+        Tries to group into the given `max_folders` if possible.
+        """
+        objects = []
+        async for objects_page in self.pages(prefix, max_depth, max_folders):
+            objects.extend(objects_page)
+        return objects
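For context, a minimal usage sketch of the API this commit introduces. The bucket name and prefix are made up for illustration; only ListObjectsAsync, pages(), and list_objects() come from the commit itself:

import asyncio

from async_s3.list_objects_async import ListObjectsAsync


async def main() -> None:
    walker = ListObjectsAsync("my-bucket")  # hypothetical bucket name

    # Stream results page by page as the traversal tasks produce them.
    async for page in walker.pages(prefix="root/", max_depth=2, max_folders=10):
        for obj in page:
            print(obj["Key"], obj["Size"])

    # Or collect everything into a single list (a thin wrapper over pages()).
    objects = await walker.list_objects(prefix="root/")
    print(f"total objects: {len(objects)}")


asyncio.run(main())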
46 changes: 23 additions & 23 deletions tests/test_list_objects_async.py
@@ -44,7 +44,7 @@ async def test_list_objects_functional(mock_s3_structure):
 async def test_list_objects_with_max_depth(s3_client_proxy, mock_s3_structure):
     walker = ListObjectsAsync("mock-bucket")

-    objects = await walker.list_objects(prefix="root/", max_depth=1)
+    objects = await walker.list_objects(prefix="root/", max_depth=2)
     expected_keys = {
         'root/data01/image01.png',
         'root/data01/images/img11.jpg',
@@ -70,13 +70,27 @@ async def test_list_objects_with_max_depth(s3_client_proxy, mock_s3_structure):
         call.get_paginator("list_objects_v2"),
         call.get_paginator().paginate(Bucket="mock-bucket", Prefix="root/", Delimiter="/"),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/', Delimiter='/'),
         call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/', Delimiter='/'),
+        call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03/', Delimiter='/'),
+        call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04/', Delimiter='/'),
+        call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/archives/'),
+        call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/docs/'),
+        call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/images/'),
+        call.get_paginator('list_objects_v2'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/temp/'),
+        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/logs/'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03/'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/reports/'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04/'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/scripts/'),
     ]
     assert s3_client_proxy.calls == expected_calls

@@ -118,26 +132,12 @@ async def test_list_objects_with_max_folders(s3_client_proxy, mock_s3_structure):
         call.get_paginator("list_objects_v2"),
         call.get_paginator().paginate(Bucket="mock-bucket", Prefix="root/", Delimiter="/"),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/', Delimiter='/'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/', Delimiter='/'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03/', Delimiter='/'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04/', Delimiter='/'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/a'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/d'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/i'),
-        call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01/t'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data01'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/l'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/r'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data03'),
         call.get_paginator('list_objects_v2'),
-        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data02/s'),
+        call.get_paginator().paginate(Bucket='mock-bucket', Prefix='root/data04'),
     ]
     assert s3_client_proxy.calls == expected_calls

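The tests pin down the exact sequence of paginator calls; the machinery producing them is the semaphore-bounded task fan-out from the source file above. Below is a distilled, self-contained sketch of that pattern with plain integers standing in for S3 pages; everything in it is illustrative, not part of the commit:

import asyncio
from typing import List, Set

MAX_CONCURRENT_TASKS = 100


async def worker(
    n: int,
    queue: "asyncio.Queue[List[int]]",
    active: "Set[asyncio.Task[None]]",
    sem: asyncio.Semaphore,
) -> None:
    """Emit one result page, then fan out to children, like _list_objects."""
    await queue.put([n])
    for child in (n * 2, n * 2 + 1):  # stand-ins for discovered sub-prefixes
        if child > 15:
            continue
        await sem.acquire()  # cap the number of concurrently running tasks
        task = asyncio.create_task(worker(child, queue, active, sem))
        active.add(task)
        task.add_done_callback(lambda t: (active.discard(t), sem.release()))


async def main() -> None:
    queue: "asyncio.Queue[List[int]]" = asyncio.Queue()
    active: "Set[asyncio.Task[None]]" = set()
    sem = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

    await sem.acquire()
    root = asyncio.create_task(worker(1, queue, active, sem))
    active.add(root)
    root.add_done_callback(lambda t: (active.discard(t), sem.release()))

    # Consume pages while producers are alive, then drain what is left.
    while active or not queue.empty():
        try:
            print(queue.get_nowait())
        except asyncio.QueueEmpty:
            await asyncio.sleep(0)


asyncio.run(main())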