Skip to content

Commit

Permalink
fix test for torch<2.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
hanouticelina committed Dec 2, 2024
1 parent 259c8f3 commit 9cb50c5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
9 changes: 6 additions & 3 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,14 +460,14 @@
"get_tf_storage_size",
"get_torch_storage_id",
"get_torch_storage_size",
"load_sharded_checkpoint",
"load_state_dict_from_file",
"load_torch_model",
"save_torch_model",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",
"split_tf_state_dict_into_shards",
"split_torch_state_dict_into_shards",
"load_torch_model",
"load_sharded_checkpoint",
"load_state_dict_from_file",
],
"utils": [
"CacheNotFound",
Expand Down Expand Up @@ -988,6 +988,9 @@ def __dir__():
get_tf_storage_size, # noqa: F401
get_torch_storage_id, # noqa: F401
get_torch_storage_size, # noqa: F401
load_sharded_checkpoint, # noqa: F401
load_state_dict_from_file, # noqa: F401
load_torch_model, # noqa: F401
save_torch_model, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
Expand Down
11 changes: 7 additions & 4 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union

from packaging import version

from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory

Expand Down Expand Up @@ -600,7 +602,7 @@ def load_state_dict_from_file(
) from e

# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined]
metadata = f.metadata()
if metadata.get("format") != "pt":
raise OSError(
Expand All @@ -609,18 +611,19 @@ def load_state_dict_from_file(
)
return load_file(checkpoint_file)
try:
import torch
from torch import load
except ImportError as e:
raise ImportError(
"Please install `safetensors` to load safetensors checkpoint. "
"You can install it with `pip install safetensors`."
) from e

additional_kwargs = {"mmap": mmap} if version.parse(torch.__version__) >= version.parse("2.1.0") else {}
return load(
checkpoint_file,
map_location=map_location,
weights_only=weights_only,
mmap=mmap,
**additional_kwargs,
)


Expand All @@ -644,7 +647,7 @@ def _load_shard_into_memory(
The loaded state dict for this shard
"""
try:
state_dict = load_fn(shard_path, **kwargs)
state_dict = load_fn(shard_path, **kwargs) # type: ignore[arg-type]
yield state_dict
finally:
# Explicitly delete the state dict
Expand Down

0 comments on commit 9cb50c5

Please sign in to comment.