diff --git a/sdk/src/beta9/cli/volume.py b/sdk/src/beta9/cli/volume.py index e55d324b9..dc77f8849 100644 --- a/sdk/src/beta9/cli/volume.py +++ b/sdk/src/beta9/cli/volume.py @@ -1,5 +1,6 @@ import functools import glob +from contextlib import contextmanager from pathlib import Path from typing import Iterable, List, Union, cast @@ -296,9 +297,30 @@ def upload(service: ServiceClient, local_path: Path, volume_name: str, remote_pa try: with StyledProgress() as p: task_id = p.add_task(local_path) - callback = cast(ProgressCallback, functools.partial(p.update, task_id=task_id)) - - multipart.upload(service.volume, local_path, volume_name, remote_path, callback) + progress_callback = cast(ProgressCallback, functools.partial(p.update, task_id=task_id)) + + @contextmanager + def completion_callback(): + """ + Shows progress status while the upload is being completed. + """ + p.stop() + + with terminal.progress("Completing...") as s: + yield s + + # Move cursor up 2x, clear line, and redraw the progress bar + terminal.print("\033[A\033[A\r", highlight=False) + p.start() + + multipart.upload( + service.volume, + local_path, + volume_name, + remote_path, + progress_callback, + completion_callback, + ) except KeyboardInterrupt: terminal.warn("\rUpload cancelled") diff --git a/sdk/src/beta9/multipart.py b/sdk/src/beta9/multipart.py index d2b0145ed..e381bad2c 100644 --- a/sdk/src/beta9/multipart.py +++ b/sdk/src/beta9/multipart.py @@ -15,6 +15,7 @@ from threading import Thread, local from typing import ( Callable, + ContextManager, Final, Generator, List, @@ -64,6 +65,10 @@ class ProgressCallback(Protocol): def __call__(self, total: int, advance: int) -> None: ... +class CompletionCallback(Protocol): + def __call__(self) -> ContextManager: ... + + P = ParamSpec("P") R = TypeVar("R") @@ -219,7 +224,8 @@ def upload( file_path: Path, volume_name: str, volume_path: str, - callback: Optional[ProgressCallback] = None, + progress_callback: Optional[ProgressCallback] = None, + completion_callback: Optional[CompletionCallback] = None, chunk_size: int = UPLOAD_CHUNK_SIZE, ): """ @@ -230,7 +236,9 @@ def upload( file_path: Path to the file to upload. volume_name: Name of the volume. volume_path: Path to the file on the volume. - callback: A callback that receives the total size and the number of bytes processed. + progress_callback: A callback that receives the total size and the number of + bytes processed. Defaults to None. + completion_callback: A context manager that wraps the completion of the upload. Defaults to None. chunk_size: Size of each chunk in bytes. Defaults to 4 MiB. @@ -260,7 +268,7 @@ def upload( executor = stack.enter_context(ProcessPoolExecutor(_MAX_WORKERS, initializer=_init)) queue = manager.Queue() - stack.enter_context(_progress_updater(file_size, queue, callback)) + stack.enter_context(_progress_updater(file_size, queue, progress_callback)) futures = ( executor.submit(_upload_part, file_path, part, queue) @@ -271,16 +279,23 @@ def upload( parts.sort(key=lambda part: part.number) # Complete multipart upload - completed = retry(times=3, delay=1.0)(service.complete_multipart_upload)( - CompleteMultipartUploadRequest( - upload_id=initial.upload_id, - volume_name=volume_name, - volume_path=volume_path, - completed_parts=parts, + def complete_upload(): + completed = retry(times=3, delay=1.0)(service.complete_multipart_upload)( + CompleteMultipartUploadRequest( + upload_id=initial.upload_id, + volume_name=volume_name, + volume_path=volume_path, + completed_parts=parts, + ) ) - ) - if not completed.ok: - raise CompleteMultipartUploadError(completed.err_msg) + if not completed.ok: + raise CompleteMultipartUploadError(completed.err_msg) + + if completion_callback is not None: + with completion_callback(): + complete_upload() + else: + complete_upload() except (Exception, KeyboardInterrupt): service.abort_multipart_upload(