From e030ab65217eed49a1d6f91fad1742363676136f Mon Sep 17 00:00:00 2001 From: Tilman Moeller Date: Mon, 26 Aug 2024 11:20:47 +0200 Subject: [PATCH] rohmu: add progress callback for transferred object keys Added callback progress_fn to copy_files_from, which tracks the progress of transferred objects as completed and total file counts. This does not measure the number of bytes, in order to avoid a provider-based implementation. --- rohmu/object_storage/base.py | 17 +++++++++++++++-- test/object_storage/test_object_storage.py | 14 ++++++++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/rohmu/object_storage/base.py b/rohmu/object_storage/base.py index c50f6118..95766929 100644 --- a/rohmu/object_storage/base.py +++ b/rohmu/object_storage/base.py @@ -56,6 +56,10 @@ class IterKeyItem(NamedTuple): IncrementalProgressCallbackType = Optional[Callable[[int], None]] +class ObjectTransferProgressCallback(Protocol): + def __call__(self, files_completed: int, total_files: int) -> None: ... + + @dataclass(frozen=True, unsafe_hash=True) class ConcurrentUpload: backend: str @@ -202,10 +206,19 @@ def copy_file( cannot be copied with this method. 
If no metadata is given copies the existing metadata.""" raise NotImplementedError - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: + def copy_files_from( + self, + *, + source: BaseTransfer[Any], + keys: Collection[str], + progress_fn: ObjectTransferProgressCallback | None = None, + ) -> None: if isinstance(source, self.__class__): - for key in keys: + total_files = len(keys) + for index, key in enumerate(keys): self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key, timeout=15) + if progress_fn is not None: + progress_fn(index + 1, total_files) else: raise NotImplementedError diff --git a/test/object_storage/test_object_storage.py b/test/object_storage/test_object_storage.py index 646f47c0..6f5df2a3 100644 --- a/test/object_storage/test_object_storage.py +++ b/test/object_storage/test_object_storage.py @@ -4,6 +4,8 @@ from rohmu import errors from rohmu.object_storage.local import LocalTransfer from typing import Any +from unittest import mock +from unittest.mock import MagicMock import pytest @@ -69,18 +71,26 @@ def test_copy(transfer_type: str, request: Any) -> None: assert transfer.get_contents_to_string("dummy_copy_metadata") == (DUMMY_CONTENT, {"new_k": "new_v"}) -def test_copy_local_files_from(tmp_path: Path) -> None: +@pytest.mark.parametrize("with_progress_fn", [False, True]) +def test_copy_local_files_from(tmp_path: Path, with_progress_fn: bool) -> None: source = LocalTransfer(tmp_path / "source", prefix="s-prefix") destination = LocalTransfer(tmp_path / "destination", prefix="d-prefix") + mock_progress_fn = MagicMock(return_value=None) source.store_file_from_memory("some/a/key.ext", b"content_a", metadata={"info": "aaa"}) source.store_file_from_memory("some/b/key.ext", b"content_b", metadata={"info": "bbb"}) + source.store_file_from_memory("some/c/key.ext", b"content_c", metadata={"info": "ccc"}) destination.copy_files_from( source=source, - keys=["some/a/key.ext", 
"some/b/key.ext"], + keys=["some/a/key.ext", "some/b/key.ext", "some/c/key.ext"], + progress_fn=mock_progress_fn if with_progress_fn else None, ) + assert destination.get_contents_to_string("some/a/key.ext") == (b"content_a", {"info": "aaa", "Content-Length": "9"}) assert destination.get_contents_to_string("some/b/key.ext") == (b"content_b", {"info": "bbb", "Content-Length": "9"}) + assert destination.get_contents_to_string("some/c/key.ext") == (b"content_c", {"info": "ccc", "Content-Length": "9"}) + if with_progress_fn: + assert mock_progress_fn.call_args_list == [mock.call(1, 3), mock.call(2, 3), mock.call(3, 3)] @pytest.mark.parametrize("transfer_type", ["local_transfer"])