Skip to content

Commit

Permalink
add wrap_file for wrapping a file object with callback (#271)
Browse files Browse the repository at this point in the history
add wrap_file for wrapping a file object with a callback
  • Loading branch information
skshetry authored Jan 10, 2024
1 parent cd89d9f commit 76eed47
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
5 changes: 2 additions & 3 deletions src/dvc_objects/fs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Sequence,
Tuple,
Union,
cast,
overload,
)
from urllib.parse import urlsplit, urlunsplit
Expand All @@ -34,8 +33,8 @@

from .callbacks import (
DEFAULT_CALLBACK,
CallbackStream,
wrap_and_branch_callback,
wrap_file,
)
from .errors import RemoteMissingDepsError

Expand Down Expand Up @@ -637,7 +636,7 @@ def put_file(
if size:
callback.set_size(size)
if hasattr(from_file, "read"):
stream = cast("BinaryIO", CallbackStream(from_file, callback))
stream = wrap_file(from_file, callback)
self.upload_fobj(stream, to_info, size=size)
else:
assert isinstance(from_file, str)
Expand Down
32 changes: 13 additions & 19 deletions src/dvc_objects/fs/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
import asyncio
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast

import fsspec

if TYPE_CHECKING:
from typing import BinaryIO, Union
from typing import Union

from dvc_objects._tqdm import Tqdm

F = TypeVar("F", bound=Callable)


class CallbackStream:
def __init__(self, stream, callback, method="read"):
def __init__(self, stream, callback: fsspec.Callback):
self.stream = stream
if method == "write":

@wraps(stream.write)
def write(data, *args, **kwargs):
res = stream.write(data, *args, **kwargs)
callback.relative_update(len(data))
return res
@wraps(stream.read)
def read(*args, **kwargs):
data = stream.read(*args, **kwargs)
callback.relative_update(len(data))
return data

self.write = write
else:

@wraps(stream.read)
def read(*args, **kwargs):
data = stream.read(*args, **kwargs)
callback.relative_update(len(data))
return data

self.read = read
self.read = read

def __getattr__(self, attr):
return getattr(self.stream, attr)
Expand Down Expand Up @@ -181,4 +171,8 @@ def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F:
return wrap_fn(callback, branch_wrapper)


def wrap_file(file, callback: fsspec.Callback) -> BinaryIO:
    """Wrap ``file`` in a ``CallbackStream`` bound to ``callback``.

    The wrapper's ``read`` forwards to ``file.read`` and then calls
    ``callback.relative_update`` with the length of the data returned;
    any other attribute lookup is delegated to ``file``.

    The return value is the ``CallbackStream`` instance itself. The
    ``cast`` only tells type checkers to treat it as a ``BinaryIO`` and
    has no runtime effect.
    """
    stream = CallbackStream(file, callback)
    return cast(BinaryIO, stream)


DEFAULT_CALLBACK = NoOpCallback()
6 changes: 3 additions & 3 deletions src/dvc_objects/fs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc_objects.executors import ThreadPoolExecutor

from . import system
from .callbacks import DEFAULT_CALLBACK, CallbackStream
from .callbacks import DEFAULT_CALLBACK, wrap_file

if TYPE_CHECKING:
from .base import AnyFSPath, FileSystem
Expand Down Expand Up @@ -168,8 +168,8 @@ def copyfile(

callback.set_size(total)
with open(src, "rb") as fsrc, open(dest, "wb+") as fdest:
wrapped = CallbackStream(fdest, callback, "write")
shutil.copyfileobj(fsrc, wrapped, length=LOCAL_CHUNK_SIZE)
wrapped = wrap_file(fsrc, callback)
shutil.copyfileobj(wrapped, fdest, length=LOCAL_CHUNK_SIZE)


def tmp_fname(prefix: str = "") -> str:
Expand Down
15 changes: 15 additions & 0 deletions tests/fs/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TqdmCallback,
branch_callback,
wrap_and_branch_callback,
wrap_file,
wrap_fn,
)

Expand Down Expand Up @@ -146,3 +147,17 @@ async def test_wrap_and_branch_callback_async(mocker, cb_class):
m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback())
assert callback.value == 2
assert spy.call_count == 2


def test_wrap_file(memfs):
    """Reading through a ``wrap_file`` wrapper must advance the callback."""
    payload = b"foo\n"
    memfs.pipe_file("/file", payload)

    progress = Callback()
    progress.set_size(4)

    with memfs.open("/file", mode="rb") as fobj:
        reader = wrap_file(fobj, progress)
        # A single unbounded read should return the whole payload.
        assert reader.read() == payload

    # All four bytes were reported, and the preset size was left untouched.
    assert progress.value == 4
    assert progress.size == 4

0 comments on commit 76eed47

Please sign in to comment.