From 9cb50c58df657721c567a1b1915d9f9584fa2da4 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Mon, 2 Dec 2024 19:01:29 +0100
Subject: [PATCH] fix test for torch<=2.1.0

---
 src/huggingface_hub/__init__.py             |  9 ++++++---
 src/huggingface_hub/serialization/_torch.py | 11 +++++++----
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index 67ffda9326..f53582dd34 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -460,14 +460,14 @@
         "get_tf_storage_size",
         "get_torch_storage_id",
         "get_torch_storage_size",
+        "load_sharded_checkpoint",
+        "load_state_dict_from_file",
+        "load_torch_model",
         "save_torch_model",
         "save_torch_state_dict",
         "split_state_dict_into_shards_factory",
         "split_tf_state_dict_into_shards",
         "split_torch_state_dict_into_shards",
-        "load_torch_model",
-        "load_sharded_checkpoint",
-        "load_state_dict_from_file",
     ],
     "utils": [
         "CacheNotFound",
@@ -988,6 +988,9 @@ def __dir__():
         get_tf_storage_size,  # noqa: F401
         get_torch_storage_id,  # noqa: F401
         get_torch_storage_size,  # noqa: F401
+        load_sharded_checkpoint,  # noqa: F401
+        load_state_dict_from_file,  # noqa: F401
+        load_torch_model,  # noqa: F401
         save_torch_model,  # noqa: F401
         save_torch_state_dict,  # noqa: F401
         split_state_dict_into_shards_factory,  # noqa: F401
diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index 0b3d91d16f..aedd46f9ca 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -23,6 +23,8 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
 
+from packaging import version
+
 from .. import constants, logging
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
 
@@ -600,7 +602,7 @@ def load_state_dict_from_file(
             ) from e
 
         # Check format of the archive
-        with safe_open(checkpoint_file, framework="pt") as f:
+        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
             metadata = f.metadata()
             if metadata.get("format") != "pt":
                 raise OSError(
@@ -609,18 +611,19 @@ def load_state_dict_from_file(
             )
         return load_file(checkpoint_file)
     try:
+        import torch
         from torch import load
     except ImportError as e:
         raise ImportError(
             "Please install `safetensors` to load safetensors checkpoint. "
             "You can install it with `pip install safetensors`."
         ) from e
-
+    additional_kwargs = {"mmap": mmap} if version.parse(torch.__version__) >= version.parse("2.1.0") else {}
     return load(
         checkpoint_file,
         map_location=map_location,
         weights_only=weights_only,
-        mmap=mmap,
+        **additional_kwargs,
     )
 
 
@@ -644,7 +647,7 @@ def _load_shard_into_memory(
         The loaded state dict for this shard
     """
     try:
-        state_dict = load_fn(shard_path, **kwargs)
+        state_dict = load_fn(shard_path, **kwargs)  # type: ignore[arg-type]
         yield state_dict
     finally:
         # Explicitly delete the state dict
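
For reference, the gist of the _torch.py change: torch.load only accepts the `mmap` keyword from torch 2.1.0 onwards, so the patch builds the kwargs dict conditionally instead of always passing `mmap=mmap`. A minimal standalone sketch of the same pattern, assuming a hypothetical `load_checkpoint` wrapper (not a huggingface_hub API):

# Sketch of the version-gated kwarg pattern used in the patch above.
# `load_checkpoint` is a hypothetical wrapper, not part of huggingface_hub.
from packaging import version

import torch


def load_checkpoint(path: str, mmap: bool = False):
    # torch.load() grew the `mmap` keyword in torch 2.1.0; on older versions
    # we omit it entirely rather than trigger a TypeError.
    extra = {"mmap": mmap} if version.parse(torch.__version__) >= version.parse("2.1.0") else {}
    return torch.load(path, map_location="cpu", weights_only=True, **extra)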
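
The __init__.py half of the diff makes the loaders importable from the package root. A usage sketch with a placeholder checkpoint path, using only the keyword arguments visible in the hunks above:

# Usage sketch for the re-exported helper; "pytorch_model.bin" is a placeholder.
from huggingface_hub import load_state_dict_from_file

state_dict = load_state_dict_from_file(
    "pytorch_model.bin",
    map_location="cpu",
    weights_only=True,
    mmap=True,  # silently dropped on torch < 2.1.0 by the version gate above
)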