diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index ae0112e76d..3c8b2b33f3 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -20,7 +20,7 @@
 from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 from .. import constants, logging
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -336,17 +336,24 @@ def split_torch_state_dict_into_shards(
     )
 
 
-def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
+def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
+    """Return a unique id for a plain tensor, or a (potentially nested)
+    tuple of unique ids for the inner tensors if the input is a wrapper
+    tensor subclass.
     """
-    Return unique identifier to a tensor storage.
-
-    Multiple different tensors can share the same underlying storage. For
-    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
-    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
-    non-overlapping lifetimes may have the same id.
+
+    try:
+        # for torch 2.1 and above, we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
+
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
 
-    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
-    """
     if tensor.device.type == "xla" and is_torch_tpu_available():
         # NOTE: xla tensors dont have storage
         # use some other unique id to distinguish.
@@ -358,13 +365,38 @@ def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, i
     else:
         unique_id = storage_ptr(tensor)
 
-    return tensor.device, unique_id, get_torch_storage_size(tensor)
+    return unique_id
+
+
+def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]:
+    """
+    Return unique identifier to a tensor storage.
+
+    Multiple different tensors can share the same underlying storage. For
+    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+
+    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
+    """
+    return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
 
 
 def get_torch_storage_size(tensor: "torch.Tensor") -> int:
     """
     Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
     """
+    try:
+        # for torch 2.1 and above, we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
     try:
         return tensor.untyped_storage().nbytes()
     except AttributeError:
@@ -398,10 +430,20 @@ def is_torch_tpu_available(check_device=True):
     return False
 
 
-def storage_ptr(tensor: "torch.Tensor") -> int:
+def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
     """
     Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
     """
+    try:
+        # for torch 2.1 and above, we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            return _get_unique_id(tensor)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
     try:
         return tensor.untyped_storage().data_ptr()
     except Exception:
@@ -496,6 +538,17 @@ def _is_complete(tensor: "torch.Tensor") -> bool:
     """
     Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
     """
+    try:
+        # for torch 2.1 and above, we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
     return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
         tensor.dtype
     ) == get_torch_storage_size(tensor)
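The recursion in _get_unique_id and get_torch_storage_size is easiest to see on a concrete wrapper subclass. Below is a minimal sketch (not part of the patch) using TwoTensor, the internal torch test helper the new tests also rely on (torch >= 2.1, not a public API), which wraps two inner tensors exposed as attributes:

import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

t = torch.tensor([1.0, 2.0, 3.0])  # float32: 3 elements * 4 bytes = 12 bytes
wrapped = TwoTensor(t, t.clone())

assert is_traceable_wrapper_subclass(wrapped)

# __tensor_flatten__ returns the attribute names of the inner tensors
# (["a", "b"] for TwoTensor) plus some opaque context that the patch ignores.
attrs, _ = wrapped.__tensor_flatten__()
inner_tensors = [getattr(wrapped, attr) for attr in attrs]

# _get_unique_id(wrapped) is therefore a tuple of the two inner storage ids,
# and get_torch_storage_size(wrapped) sums the inner sizes: 12 + 12 = 24 bytes.

For a plain tensor, is_traceable_wrapper_subclass returns False and all four patched helpers fall through to the original storage-pointer logic unchanged.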
+ """ + return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor) def get_torch_storage_size(tensor: "torch.Tensor") -> int: """ Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + try: return tensor.untyped_storage().nbytes() except AttributeError: @@ -398,10 +430,20 @@ def is_torch_tpu_available(check_device=True): return False -def storage_ptr(tensor: "torch.Tensor") -> int: +def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11. """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + return _get_unique_id(tensor) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + try: return tensor.untyped_storage().data_ptr() except Exception: @@ -496,6 +538,17 @@ def _is_complete(tensor: "torch.Tensor") -> bool: """ Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return all(_is_complete(getattr(tensor, attr)) for attr in attrs) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size( tensor.dtype ) == get_torch_storage_size(tensor) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index d954fce99b..019ec26f2d 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -7,12 +7,14 @@ import pytest from pytest_mock import MockerFixture +from huggingface_hub import constants from huggingface_hub.serialization import ( get_tf_storage_size, get_torch_storage_size, save_torch_model, save_torch_state_dict, split_state_dict_into_shards_factory, + split_torch_state_dict_into_shards, ) from huggingface_hub.serialization._base import parse_size_to_int @@ -31,6 +33,16 @@ def _dummy_get_storage_size(item): return sum(item) +# util functions for checking the version for pytorch +def is_wrapper_tensor_subclass_available(): + try: + from torch.utils._python_dispatch import is_traceable_wrapper_subclass # noqa: F401 + + return True + except ImportError: + return False + + @pytest.fixture def dummy_state_dict() -> Dict[str, List[int]]: return { @@ -58,6 +70,25 @@ def torch_state_dict() -> Dict[str, "torch.Tensor"]: pytest.skip("torch is not available") +@pytest.fixture +def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]: + try: + 
+        import torch
+        from torch.testing._internal.two_tensor import TwoTensor
+
+        t = torch.tensor([4])
+        return {
+            "layer_1": torch.tensor([4]),
+            "layer_2": torch.tensor([10]),
+            "layer_3": torch.tensor([30]),
+            "layer_4": torch.tensor([2]),
+            "layer_5": torch.tensor([2]),
+            "layer_6": TwoTensor(t, t),
+        }
+    except ImportError:
+        pytest.skip("torch is not available")
+
+
 @pytest.fixture
 def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
     try:
@@ -75,6 +106,31 @@ def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
         pytest.skip("torch is not available")
 
 
+@pytest.fixture
+def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]:
+    try:
+        import torch
+        from torch.testing._internal.two_tensor import TwoTensor
+
+        t = torch.tensor([4])
+        tensor_subclass_tensor = TwoTensor(t, t)
+
+        t = torch.tensor([4])
+        shared_tensor_subclass_tensor = TwoTensor(t, t)
+        return {
+            "layer_1": torch.tensor([4]),
+            "layer_2": torch.tensor([10]),
+            "layer_3": torch.tensor([30]),
+            "layer_4": torch.tensor([2]),
+            "layer_5": torch.tensor([2]),
+            "layer_6": tensor_subclass_tensor,
+            "ts_shared_1": shared_tensor_subclass_tensor,
+            "ts_shared_2": shared_tensor_subclass_tensor,
+        }
+    except ImportError:
+        pytest.skip("torch is not available")
+
+
 def test_single_shard(dummy_state_dict):
     state_dict_split = split_state_dict_into_shards_factory(
         dummy_state_dict,
@@ -170,6 +226,18 @@ def test_get_torch_storage_size():
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
 
 
+@requires("torch")
+@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
+def test_get_torch_storage_size_wrapper_tensor_subclass():
+    import torch
+    from torch.testing._internal.two_tensor import TwoTensor
+
+    t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)
+    assert get_torch_storage_size(TwoTensor(t, t)) == 5 * 8 * 2
+    t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    assert get_torch_storage_size(TwoTensor(t, TwoTensor(t, t))) == 5 * 2 * 3
+
+
 def test_parse_size_to_int():
     assert parse_size_to_int("1KB") == 1 * 10**3
     assert parse_size_to_int("2MB") == 2 * 10**6
@@ -247,6 +315,38 @@ def test_save_torch_state_dict_unsafe_not_sharded(
     assert not (tmp_path / "pytorch_model.bin.index.json").is_file()
 
 
+@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
+def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded(
+    tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"]
+) -> None:
+    """Save as pickle without sharding."""
+    with caplog.at_level("WARNING"):
+        save_torch_state_dict(
+            torch_state_dict_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False
+        )
+    assert "we strongly recommend using safe serialization" in caplog.text
+
+    assert (tmp_path / "pytorch_model.bin").is_file()
+    assert not (tmp_path / "pytorch_model.bin.index.json").is_file()
+
+
+@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
+def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded(
+    tmp_path: Path,
+    caplog: pytest.LogCaptureFixture,
+    torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"],
+) -> None:
+    """Save as pickle without sharding."""
+    with caplog.at_level("WARNING"):
+        save_torch_state_dict(
+            torch_state_dict_shared_layers_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False
) + assert "we strongly recommend using safe serialization" in caplog.text + + assert (tmp_path / "pytorch_model.bin").is_file() + assert not (tmp_path / "pytorch_model.bin.index.json").is_file() + + def test_save_torch_state_dict_unsafe_sharded( tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] ) -> None: @@ -314,6 +414,18 @@ def test_save_torch_state_dict_shared_layers_sharded( assert "shared_2" not in state_dict +def test_split_torch_state_dict_into_shards( + tmp_path: Path, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"] +): + # the model size is 72, setting max_shard_size to 32 means we'll shard the file + state_dict_split = split_torch_state_dict_into_shards( + torch_state_dict_shared_layers_tensor_subclass, + filename_pattern=constants.PYTORCH_WEIGHTS_FILE_PATTERN, + max_shard_size=32, + ) + assert state_dict_split.is_sharded + + def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: """Custom filename pattern is respected.""" # Not sharded
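For completeness, here is a usage sketch of the new sharding path (assumptions: torch >= 2.1 and the internal TwoTensor helper; sizes follow the fixtures above, where each single-element int64 tensor takes 8 bytes and a TwoTensor counts the sum of its inner tensors):

import torch
from torch.testing._internal.two_tensor import TwoTensor

from huggingface_hub import constants
from huggingface_hub.serialization import split_torch_state_dict_into_shards

t = torch.tensor([4])
state_dict = {
    "layer_1": torch.tensor([4]),  # five plain int64 tensors, 8 bytes each
    "layer_2": torch.tensor([10]),
    "layer_3": torch.tensor([30]),
    "layer_4": torch.tensor([2]),
    "layer_5": torch.tensor([2]),
    "layer_6": TwoTensor(t, t),  # 16 bytes: two 8-byte inner tensors
}

# 56 bytes in total, so a 32-byte shard limit must produce more than one shard.
state_dict_split = split_torch_state_dict_into_shards(
    state_dict,
    filename_pattern=constants.PYTORCH_WEIGHTS_FILE_PATTERN,
    max_shard_size=32,
)
print(state_dict_split.is_sharded)  # True
print(state_dict_split.filename_to_tensors)  # shard filename -> tensor names

Because get_torch_storage_id now keys on the flattened inner tensors, two entries that reuse the same wrapper object (like ts_shared_1 and ts_shared_2 in the shared-layers fixture) should resolve to the same storage id and be deduplicated on save rather than written twice.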