Skip to content

Commit

Permalink
https://github.com/andgineer/aws-s3/issues/3
Browse files Browse the repository at this point in the history
limit level of folder recursion
  • Loading branch information
andgineer committed Jun 3, 2024
1 parent 1be8dfc commit 81adac2
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions src/aws_s3/list_objects_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import functools
from typing import Iterable, Any, Dict
from typing import Iterable, Any, Dict, Optional, Tuple, List

import aiobotocore.session
import aiobotocore.client
Expand All @@ -27,47 +27,60 @@ def __init__(self, bucket: str) -> None:
self._bucket = bucket

async def _list_objects(
    self,
    s3_client: aiobotocore.client.AioBaseClient,
    prefix: str,
    current_depth: int,
    max_depth: Optional[int],
) -> Tuple[Iterable[Dict[str, Any]], Iterable[Tuple[str, int]]]:
    """List objects directly under ``prefix`` plus subfolders still to visit.

    While ``current_depth`` is below ``max_depth`` the listing uses the ``"/"``
    delimiter, so S3 groups deeper keys into ``CommonPrefixes`` ("subfolders")
    that the caller can fan out over concurrently.  Once the depth limit is
    reached the delimiter is dropped and the final level is listed flat, so
    no further recursion is produced.

    Args:
        s3_client: An open aiobotocore S3 client.
        prefix: Key prefix to list under.
        current_depth: Recursion depth of this prefix (0 for the root call).
        max_depth: Maximum folder recursion depth; ``None`` means unlimited.

    Returns:
        A ``(objects, subfolders)`` pair: ``objects`` are S3 object dicts
        (``Key``, ``Size``, ...) and ``subfolders`` are ``(prefix, depth)``
        pairs that still need to be listed.
    """
    paginator = s3_client.get_paginator("list_objects_v2")
    objects: List[Dict[str, Any]] = []
    subfolders: List[Tuple[str, int]] = []

    params: Dict[str, str] = {"Bucket": self._bucket, "Prefix": prefix}
    if max_depth is None or current_depth < max_depth:
        params["Delimiter"] = "/"  # group deeper keys into CommonPrefixes

    async for page in paginator.paginate(**params):
        for obj in page.get("Contents", []):
            key: str = obj["Key"]
            if key.endswith("/"):
                continue  # Omit zero-byte "directory" placeholder objects
            objects.append(obj)

        if "Delimiter" in params:
            # Distinct loop variable: the original shadowed the `prefix`
            # parameter inside this comprehension, which is error-prone.
            subfolders.extend(
                (sub["Prefix"], current_depth + 1)
                for sub in page.get("CommonPrefixes", [])
            )

    return objects, subfolders

async def list_objects(
self,
prefix: str = "/",
self, prefix: str = "/", max_depth: Optional[int] = None
) -> Iterable[Dict[str, Any]]:
"""List all objects in the bucket with given prefix."""
objects = []
objects: List[Dict[str, Any]] = []
tasks = set()

async with get_s3_client() as s3_client:
tasks.add(asyncio.create_task(self._list_objects(s3_client, prefix)))
tasks.add(asyncio.create_task(self._list_objects(s3_client, prefix, 0, max_depth)))

while tasks:
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
tasks = pending

for task in done:
result = await task
objects.extend(result["Objects"])
files, folders = await task
objects.extend(files)

for common_prefix in result["CommonPrefixes"]:
for folder, level in folders:
tasks.add(
asyncio.create_task(
self._list_objects(s3_client, common_prefix["Prefix"])
self._list_objects(s3_client, folder, level, max_depth)
)
)

Expand Down

0 comments on commit 81adac2

Please sign in to comment.