diff --git a/rohmu/object_storage/base.py b/rohmu/object_storage/base.py index c50f6118..a161d1ef 100644 --- a/rohmu/object_storage/base.py +++ b/rohmu/object_storage/base.py @@ -35,7 +35,7 @@ TypeVar, Union, ) -from typing_extensions import Self +from typing_extensions import Self, TypeAlias import logging import os @@ -55,6 +55,9 @@ class IterKeyItem(NamedTuple): # Argument is the additional number of bytes transferred IncrementalProgressCallbackType = Optional[Callable[[int], None]] +# Argument is the transferred object key +ObjectTransferProgressCallbackType: TypeAlias = Optional[Callable[[str], None]] + @dataclass(frozen=True, unsafe_hash=True) class ConcurrentUpload: @@ -202,10 +205,18 @@ def copy_file( cannot be copied with this method. If no metadata is given copies the existing metadata.""" raise NotImplementedError - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: + def copy_files_from( + self, + *, + source: BaseTransfer[SourceStorageModelT], + keys: Collection[str], + progress_fn: ObjectTransferProgressCallbackType = None, + ) -> None: if isinstance(source, self.__class__): for key in keys: self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key, timeout=15) + if progress_fn is not None: + progress_fn(key) else: raise NotImplementedError diff --git a/test/object_storage/test_object_storage.py b/test/object_storage/test_object_storage.py index 646f47c0..a3c9f01b 100644 --- a/test/object_storage/test_object_storage.py +++ b/test/object_storage/test_object_storage.py @@ -4,6 +4,8 @@ from rohmu import errors from rohmu.object_storage.local import LocalTransfer from typing import Any +from unittest import mock +from unittest.mock import MagicMock import pytest @@ -72,15 +74,14 @@ def test_copy(transfer_type: str, request: Any) -> None: def test_copy_local_files_from(tmp_path: Path) -> None: source = LocalTransfer(tmp_path / "source", prefix="s-prefix") destination = LocalTransfer(tmp_path / "destination", prefix="d-prefix") + mock_progress_fn = MagicMock(return_value=None) source.store_file_from_memory("some/a/key.ext", b"content_a", metadata={"info": "aaa"}) source.store_file_from_memory("some/b/key.ext", b"content_b", metadata={"info": "bbb"}) - destination.copy_files_from( - source=source, - keys=["some/a/key.ext", "some/b/key.ext"], - ) + destination.copy_files_from(source=source, keys=["some/a/key.ext", "some/b/key.ext"], progress_fn=mock_progress_fn) assert destination.get_contents_to_string("some/a/key.ext") == (b"content_a", {"info": "aaa", "Content-Length": "9"}) assert destination.get_contents_to_string("some/b/key.ext") == (b"content_b", {"info": "bbb", "Content-Length": "9"}) + assert mock_progress_fn.call_args_list == [mock.call("some/a/key.ext"), mock.call("some/b/key.ext")] @pytest.mark.parametrize("transfer_type", ["local_transfer"])