Skip to content

Commit

Permalink
[HfFileSystem] Copy non lfs files (#1996)
Browse files Browse the repository at this point in the history
* copy non lfs files

* mypy

* add test

* fix test

* Apply suggestions from code review

Co-authored-by: Lucain <lucainp@gmail.com>

* fix import

---------

Co-authored-by: Lucain <lucainp@gmail.com>
  • Loading branch information
lhoestq and Wauplin authored Feb 15, 2024
1 parent c524a86 commit 434c60c
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 29 deletions.
69 changes: 49 additions & 20 deletions src/huggingface_hub/_commit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from huggingface_hub import get_session

from .constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER
from .file_download import hf_hub_url
from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
from .utils import (
EntryNotFoundError,
Expand Down Expand Up @@ -521,16 +522,19 @@ def _fetch_upload_modes(


@validate_hf_hub_args
def _fetch_lfs_files_to_copy(
def _fetch_files_to_copy(
copies: Iterable[CommitOperationCopy],
repo_type: str,
repo_id: str,
token: Optional[str],
revision: str,
endpoint: Optional[str] = None,
) -> Dict[Tuple[str, Optional[str]], "RepoFile"]:
) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]:
"""
Requests the Hub files information of the LFS files to be copied, including their sha256.
Fetch information about the files to copy.
For LFS files, we only need their metadata (file size and sha256) while for regular files
we need to download the raw content from the Hub.
Args:
copies (`Iterable` of :class:`CommitOperationCopy`):
Expand All @@ -546,8 +550,9 @@ def _fetch_lfs_files_to_copy(
revision (`str`):
The git revision to upload the files to. Can be any valid git revision.
Returns: `Dict[Tuple[str, Optional[str]], RepoFile]]`
Key is the file path and revision of the file to copy, value is the repo file.
Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]`
Key is the file path and revision of the file to copy.
Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files).
Raises:
[`~utils.HfHubHTTPError`]
Expand All @@ -558,7 +563,7 @@ def _fetch_lfs_files_to_copy(
from .hf_api import HfApi, RepoFolder

hf_api = HfApi(endpoint=endpoint, token=token)
files_to_copy = {}
files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {}
for src_revision, operations in groupby(copies, key=lambda op: op.src_revision):
operations = list(operations) # type: ignore
paths = [op.src_path_in_repo for op in operations]
Expand All @@ -572,9 +577,21 @@ def _fetch_lfs_files_to_copy(
for src_repo_file in src_repo_files:
if isinstance(src_repo_file, RepoFolder):
raise NotImplementedError("Copying a folder is not implemented.")
if not src_repo_file.lfs:
raise NotImplementedError("Copying a non-LFS file is not implemented")
files_to_copy[(src_repo_file.rfilename, src_revision)] = src_repo_file
if src_repo_file.lfs:
files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file
else:
# TODO: (optimization) download regular files to copy concurrently
headers = build_hf_headers(token=token)
url = hf_hub_url(
endpoint=endpoint,
repo_type=repo_type,
repo_id=repo_id,
revision=src_revision or revision,
filename=src_repo_file.path,
)
response = get_session().get(url, headers=headers)
hf_raise_for_status(response)
files_to_copy[(src_repo_file.path, src_revision)] = response.content
for operation in operations:
if (operation.src_path_in_repo, src_revision) not in files_to_copy:
raise EntryNotFoundError(
Expand All @@ -586,7 +603,7 @@ def _fetch_lfs_files_to_copy(

def _prepare_commit_payload(
operations: Iterable[CommitOperation],
files_to_copy: Dict[Tuple[str, Optional[str]], "RepoFile"],
files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]],
commit_message: str,
commit_description: Optional[str] = None,
parent_commit: Optional[str] = None,
Expand Down Expand Up @@ -649,16 +666,28 @@ def _prepare_commit_payload(
# 2.d. Case copying a file or folder
elif isinstance(operation, CommitOperationCopy):
file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)]
if not file_to_copy.lfs:
raise NotImplementedError("Copying a non-LFS file is not implemented")
yield {
"key": "lfsFile",
"value": {
"path": operation.path_in_repo,
"algo": "sha256",
"oid": file_to_copy.lfs["sha256"],
},
}
if isinstance(file_to_copy, bytes):
yield {
"key": "file",
"value": {
"content": base64.b64encode(file_to_copy).decode(),
"path": operation.path_in_repo,
"encoding": "base64",
},
}
elif file_to_copy.lfs:
yield {
"key": "lfsFile",
"value": {
"path": operation.path_in_repo,
"algo": "sha256",
"oid": file_to_copy.lfs.sha256,
},
}
else:
raise ValueError(
"Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info."
)
# 2.e. Never expected to happen
else:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
CommitOperationAdd,
CommitOperationCopy,
CommitOperationDelete,
_fetch_lfs_files_to_copy,
_fetch_files_to_copy,
_fetch_upload_modes,
_prepare_commit_payload,
_upload_lfs_files,
Expand Down Expand Up @@ -3614,7 +3614,7 @@ def create_commit(
num_threads=num_threads,
free_memory=False, # do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for "normal" users
)
files_to_copy = _fetch_lfs_files_to_copy(
files_to_copy = _fetch_files_to_copy(
copies=copies,
repo_type=repo_type,
repo_id=repo_id,
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/hf_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwar
resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
)

if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None:
if same_repo:
commit_message = f"Copy {path1} to {path2}"
self._api.create_commit(
repo_id=resolved_path1.repo_id,
Expand Down
12 changes: 6 additions & 6 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,12 +899,11 @@ def test_commit_copy_file(self, repo_url: RepoUrl) -> None:
CommitOperationCopy(src_path_in_repo="lfs.bin", path_in_repo="lfs Copy (1).bin"),
],
)
with self.assertRaises(NotImplementedError):
self._api.create_commit(
repo_id=repo_id,
commit_message="Copy regular file.",
operations=[CommitOperationCopy(src_path_in_repo="file.txt", path_in_repo="file Copy.txt")],
)
self._api.create_commit(
repo_id=repo_id,
commit_message="Copy regular file.",
operations=[CommitOperationCopy(src_path_in_repo="file.txt", path_in_repo="file Copy.txt")],
)
with self.assertRaises(EntryNotFoundError):
self._api.create_commit(
repo_id=repo_id,
Expand All @@ -917,6 +916,7 @@ def test_commit_copy_file(self, repo_url: RepoUrl) -> None:
# Check repo files
repo_files = self._api.list_repo_files(repo_id=repo_id)
self.assertIn("file.txt", repo_files)
self.assertIn("file Copy.txt", repo_files)
self.assertIn("lfs.bin", repo_files)
self.assertIn("lfs Copy.bin", repo_files)
self.assertIn("lfs Copy (1).bin", repo_files)
Expand Down

0 comments on commit 434c60c

Please sign in to comment.