From 122a057caa10a005ef30d0258a3f8fb04fc14fb4 Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 7 Jun 2024 15:38:17 +0200 Subject: [PATCH] Serialization: support saving torch state dict to disk (#2314) * Save Pytorch state dict * fix tests * fix test * Add tests * docstring * docs * Skip if torch not installed * Respect custom filename pattern * Update docs/source/en/package_reference/serialization.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../en/package_reference/serialization.md | 12 +- .../ko/package_reference/serialization.md | 4 - src/huggingface_hub/__init__.py | 4 +- src/huggingface_hub/constants.py | 6 + src/huggingface_hub/serialization/__init__.py | 3 +- src/huggingface_hub/serialization/_base.py | 21 +-- src/huggingface_hub/serialization/_numpy.py | 68 ------- .../serialization/_tensorflow.py | 3 +- src/huggingface_hub/serialization/_torch.py | 140 +++++++++++++- tests/test_serialization.py | 173 +++++++++++++++--- 10 files changed, 316 insertions(+), 118 deletions(-) delete mode 100644 src/huggingface_hub/serialization/_numpy.py diff --git a/docs/source/en/package_reference/serialization.md b/docs/source/en/package_reference/serialization.md index f63d4e343a..c2a7388091 100644 --- a/docs/source/en/package_reference/serialization.md +++ b/docs/source/en/package_reference/serialization.md @@ -4,15 +4,17 @@ rendered properly in your Markdown viewer. # Serialization -`huggingface_hub` contains helpers to help ML libraries to serialize models weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub. +`huggingface_hub` contains helpers to help ML libraries serialize model weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub. -## Split state dict into shards +## Save torch state dict + +The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards, creates a proper index in the process, and saves everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the splitting logic to [`split_torch_state_dict_into_shards`]. -At the moment, this module contains a single helper that takes a state dictionary (e.g. a mapping between layer names and related tensors) and split it into several shards, while creating a proper index in the process. This helper is available for `torch`, `tensorflow` and `numpy` tensors and is designed to be easily extended to any other ML frameworks. +[[autodoc]] huggingface_hub.save_torch_state_dict -### split_numpy_state_dict_into_shards +## Split state dict into shards -[[autodoc]] huggingface_hub.split_numpy_state_dict_into_shards +The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML framework. 
### split_tf_state_dict_into_shards diff --git a/docs/source/ko/package_reference/serialization.md b/docs/source/ko/package_reference/serialization.md index a3b515c1e7..d026052eda 100644 --- a/docs/source/ko/package_reference/serialization.md +++ b/docs/source/ko/package_reference/serialization.md @@ -10,10 +10,6 @@ rendered properly in your Markdown viewer. 현재 이 모듈은 상태 딕셔너리(예: 레이어 이름과 관련 텐서 간의 매핑)를 받아 여러 샤드로 나누고, 이 과정에서 적절한 인덱스를 생성하는 단일 헬퍼를 포함하고 있습니다. 이 헬퍼는 `torch`, `tensorflow`, `numpy` 텐서에 사용 가능하며, 다른 ML 프레임워크로 쉽게 확장될 수 있도록 설계되었습니다. -### split_numpy_state_dict_into_shards[[huggingface_hub.split_numpy_state_dict_into_shards]] - -[[autodoc]] huggingface_hub.split_numpy_state_dict_into_shards - ### split_tf_state_dict_into_shards[[huggingface_hub.split_tf_state_dict_into_shards]] [[autodoc]] huggingface_hub.split_tf_state_dict_into_shards diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index fdbd33d6a9..41d64f9f1b 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -424,7 +424,7 @@ "serialization": [ "StateDictSplit", "get_torch_storage_id", - "split_numpy_state_dict_into_shards", + "save_torch_state_dict", "split_state_dict_into_shards_factory", "split_tf_state_dict_into_shards", "split_torch_state_dict_into_shards", @@ -904,7 +904,7 @@ def __dir__(): from .serialization import ( StateDictSplit, # noqa: F401 get_torch_storage_id, # noqa: F401 - split_numpy_state_dict_into_shards, # noqa: F401 + save_torch_state_dict, # noqa: F401 split_state_dict_into_shards_factory, # noqa: F401 split_tf_state_dict_into_shards, # noqa: F401 split_torch_state_dict_into_shards, # noqa: F401 diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index fc6d8c5e44..4e999af621 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -37,6 +37,12 @@ def _as_int(value: Optional[str]) -> Optional[int]: DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024 HF_TRANSFER_CONCURRENCY = 100 +# Constants for serialization + +PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin" # Unsafe pickle: use safetensors instead +SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors" +TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5" + # Constants for safetensors repos SAFETENSORS_SINGLE_FILE = "model.safetensors" diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py index 2d3fe3aa37..2ae8f4aa1d 100644 --- a/src/huggingface_hub/serialization/__init__.py +++ b/src/huggingface_hub/serialization/__init__.py @@ -15,6 +15,5 @@ """Contains helpers to serialize tensors.""" from ._base import StateDictSplit, split_state_dict_into_shards_factory -from ._numpy import split_numpy_state_dict_into_shards from ._tensorflow import split_tf_state_dict_into_shards -from ._torch import get_torch_storage_id, split_torch_state_dict_into_shards +from ._torch import get_torch_storage_id, save_torch_state_dict, split_torch_state_dict_into_shards diff --git a/src/huggingface_hub/serialization/_base.py b/src/huggingface_hub/serialization/_base.py index e16e4a8137..c08d39b5ae 100644 --- a/src/huggingface_hub/serialization/_base.py +++ b/src/huggingface_hub/serialization/_base.py @@ -23,8 +23,14 @@ TensorSizeFn_T = Callable[[TensorT], int] StorageIDFn_T = Callable[[TensorT], Optional[Any]] -MAX_SHARD_SIZE = 5_000_000_000 # 5GB -FILENAME_PATTERN = "model{suffix}.safetensors" +MAX_SHARD_SIZE = "5GB" +SIZE_UNITS = { + "TB": 10**12, + "GB": 10**9, + "MB": 10**6, + "KB": 10**3, +} + logger = 
logging.get_logger(__file__) @@ -44,8 +50,8 @@ def split_state_dict_into_shards_factory( state_dict: Dict[str, TensorT], *, get_tensor_size: TensorSizeFn_T, + filename_pattern: str, get_storage_id: StorageIDFn_T = lambda tensor: None, - filename_pattern: str = FILENAME_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ @@ -75,7 +81,6 @@ def split_state_dict_into_shards_factory( filename_pattern (`str`, *optional*): The pattern to generate the files names in which the model will be saved. Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` - Defaults to `"model{suffix}.safetensors"`. max_shard_size (`int` or `str`, *optional*): The maximum size of each shard, in bytes. Defaults to 5GB. @@ -172,14 +177,6 @@ def split_state_dict_into_shards_factory( ) -SIZE_UNITS = { - "TB": 10**12, - "GB": 10**9, - "MB": 10**6, - "KB": 10**3, -} - - def parse_size_to_int(size_as_str: str) -> int: """ Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). diff --git a/src/huggingface_hub/serialization/_numpy.py b/src/huggingface_hub/serialization/_numpy.py deleted file mode 100644 index 19b5a26aef..0000000000 --- a/src/huggingface_hub/serialization/_numpy.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Contains numpy-specific helpers.""" - -from typing import TYPE_CHECKING, Dict, Union - -from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory - - -if TYPE_CHECKING: - import numpy as np - - -def split_numpy_state_dict_into_shards( - state_dict: Dict[str, "np.ndarray"], - *, - filename_pattern: str = FILENAME_PATTERN, - max_shard_size: Union[int, str] = MAX_SHARD_SIZE, -) -> StateDictSplit: - """ - Split a model state dictionary in shards so that each shard is smaller than a given size. - - The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization - made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we - have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not - [6+2+2GB], [6+2GB], [6GB]. - - - - If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a - size greater than `max_shard_size`. - - - - Args: - state_dict (`Dict[str, np.ndarray]`): - The state dictionary to save. - filename_pattern (`str`, *optional*): - The pattern to generate the files names in which the model will be saved. Pattern must be a string that - can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` - Defaults to `"model{suffix}.safetensors"`. - max_shard_size (`int` or `str`, *optional*): - The maximum size of each shard, in bytes. Defaults to 5GB. 
- - Returns: - [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. - """ - return split_state_dict_into_shards_factory( - state_dict, - max_shard_size=max_shard_size, - filename_pattern=filename_pattern, - get_tensor_size=get_tensor_size, - ) - - -def get_tensor_size(tensor: "np.ndarray") -> int: - return tensor.nbytes diff --git a/src/huggingface_hub/serialization/_tensorflow.py b/src/huggingface_hub/serialization/_tensorflow.py index f3818b0ae3..943ff296b4 100644 --- a/src/huggingface_hub/serialization/_tensorflow.py +++ b/src/huggingface_hub/serialization/_tensorflow.py @@ -17,6 +17,7 @@ import re from typing import TYPE_CHECKING, Dict, Union +from .. import constants from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory @@ -27,7 +28,7 @@ def split_tf_state_dict_into_shards( state_dict: Dict[str, "tf.Tensor"], *, - filename_pattern: str = "tf_model{suffix}.h5", + filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py index 349e7312e4..36bac7b284 100644 --- a/src/huggingface_hub/serialization/_torch.py +++ b/src/huggingface_hub/serialization/_torch.py @@ -14,12 +14,19 @@ """Contains pytorch-specific helpers.""" import importlib +import json +import os +import re from functools import lru_cache -from typing import TYPE_CHECKING, Dict, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union -from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory +from .. import constants, logging +from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory +logger = logging.get_logger(__file__) + if TYPE_CHECKING: import torch @@ -27,7 +34,7 @@ def split_torch_state_dict_into_shards( state_dict: Dict[str, "torch.Tensor"], *, - filename_pattern: str = FILENAME_PATTERN, + filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ @@ -38,6 +45,14 @@ def split_torch_state_dict_into_shards( have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses + `split_torch_state_dict_into_shards` under the hood. + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a @@ -92,6 +107,125 @@ def split_torch_state_dict_into_shards( ) +def save_torch_state_dict( + state_dict: Dict[str, "torch.Tensor"], + save_directory: Union[str, Path], + *, + safe_serialization: bool = True, + filename_pattern: Optional[str] = None, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> None: + """ + Save a model state dictionary to the disk. + + The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are + saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, + an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses + [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as + safetensors (the default). 
Otherwise, the shards are saved as pickle. + + Before saving the model, the `save_directory` is cleaned of any previous shard files. + + <Tip warning={true}> + + If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard, which will have a + size greater than `max_shard_size`. + + </Tip> + + Args: + state_dict (`Dict[str, torch.Tensor]`): + The state dictionary to save. + save_directory (`str` or `Path`): + The directory in which the model will be saved. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. + filename_pattern (`str`, *optional*): + The pattern to generate the file names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`. + Defaults to `"model{suffix}.safetensors"` or `"pytorch_model{suffix}.bin"` depending on the `safe_serialization` + parameter. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Example: + + ```py + >>> from huggingface_hub import save_torch_state_dict + >>> model = ... # A PyTorch model + + # Save state dict to "path/to/folder". The model will be split into shards of at most 5GB each and saved as safetensors. + >>> state_dict = model.state_dict() + >>> save_torch_state_dict(state_dict, "path/to/folder") + ``` + """ + save_directory = str(save_directory) + + if filename_pattern is None: + filename_pattern = ( + constants.SAFETENSORS_WEIGHTS_FILE_PATTERN + if safe_serialization + else constants.PYTORCH_WEIGHTS_FILE_PATTERN + ) + + # Import the correct save function + if safe_serialization: + try: + from safetensors.torch import save_file as save_file_fn + except ImportError as e: + raise ImportError( + "Please install `safetensors` to use safe serialization. " + "You can install it with `pip install safetensors`." + ) from e + + else: + from torch import save as save_file_fn # type: ignore[assignment] + + logger.warning( + "You are using unsafe serialization. Due to security reasons, it is recommended not to load " + "pickled models from untrusted sources. If you intend to share your model, we strongly recommend " + "using safe serialization by installing `safetensors` with `pip install safetensors`." + ) + + # Split dict + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + + # Clean the folder from any previous save + existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?") + for filename in os.listdir(save_directory): + if existing_files_regex.match(filename): + try: + logger.debug(f"Removing existing file '{filename}' from folder.") + os.remove(os.path.join(save_directory, filename)) + except Exception as e: + logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. 
Continuing...") + + # Save each shard + safe_file_kwargs = {"metadata": {"format": "pt"}} if safe_serialization else {} + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs) + logger.debug(f"Shard saved to {filename}") + + # Save the index (if any) + if state_dict_split.is_sharded: + index_path = filename_pattern.format(suffix="") + ".index.json" + index = {"metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename} + with open(os.path.join(save_directory, index_path), "w") as f: + json.dump(index, f, indent=2) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). " + f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. " + f"You can find where each parameters has been saved in the index located at {index_path}." + ) + + logger.info(f"Model weights successfully saved to {save_directory}!") + + def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]: """ Return unique identifier to a tensor storage. diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 47a78d5e2e..9af9256389 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,21 +1,19 @@ +import json +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List + import pytest -from huggingface_hub.serialization import split_state_dict_into_shards_factory +from huggingface_hub.serialization import save_torch_state_dict, split_state_dict_into_shards_factory from huggingface_hub.serialization._base import parse_size_to_int -from huggingface_hub.serialization._numpy import get_tensor_size as get_tensor_size_numpy from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch from .testing_utils import requires -DUMMY_STATE_DICT = { - "layer_1": [6], - "layer_2": [10], - "layer_3": [30], - "layer_4": [2], - "layer_5": [2], -} +if TYPE_CHECKING: + import torch def _dummy_get_storage_id(item): @@ -26,9 +24,36 @@ def _dummy_get_tensor_size(item): return sum(item) -def test_single_shard(): +@pytest.fixture +def dummy_state_dict() -> Dict[str, List[int]]: + return { + "layer_1": [6], + "layer_2": [10], + "layer_3": [30], + "layer_4": [2], + "layer_5": [2], + } + + +@pytest.fixture +def torch_state_dict() -> Dict[str, "torch.Tensor"]: + try: + import torch + + return { + "layer_1": torch.tensor([4]), + "layer_2": torch.tensor([10]), + "layer_3": torch.tensor([30]), + "layer_4": torch.tensor([2]), + "layer_5": torch.tensor([2]), + } + except ImportError: + pytest.skip("torch is not available") + + +def test_single_shard(dummy_state_dict): state_dict_split = split_state_dict_into_shards_factory( - DUMMY_STATE_DICT, + dummy_state_dict, get_storage_id=_dummy_get_storage_id, get_tensor_size=_dummy_get_tensor_size, max_shard_size=100, # large shard size => only one shard @@ -49,9 +74,9 @@ def test_single_shard(): assert state_dict_split.metadata == {"total_size": 50} -def test_multiple_shards(): +def test_multiple_shards(dummy_state_dict): state_dict_split = split_state_dict_into_shards_factory( - DUMMY_STATE_DICT, + dummy_state_dict, get_storage_id=_dummy_get_storage_id, get_tensor_size=_dummy_get_tensor_size, max_shard_size=10, # small shard size => multiple shards @@ 
-88,6 +113,7 @@ def test_tensor_same_storage(): get_storage_id=lambda x: (x[0]), # dummy for test: storage id based on first element get_tensor_size=_dummy_get_tensor_size, max_shard_size=1, + filename_pattern="model{suffix}.safetensors", ) assert state_dict_split.is_sharded assert state_dict_split.filename_to_tensors == { @@ -104,14 +130,6 @@ def test_tensor_same_storage(): assert state_dict_split.metadata == {"total_size": 3} # count them once -@requires("numpy") -def test_get_tensor_size_numpy(): - import numpy as np - - assert get_tensor_size_numpy(np.array([1, 2, 3, 4, 5], dtype=np.float64)) == 5 * 8 - assert get_tensor_size_numpy(np.array([1, 2, 3, 4, 5], dtype=np.float16)) == 5 * 2 - - @requires("tensorflow") def test_get_tensor_size_tensorflow(): import tensorflow as tf @@ -140,3 +158,116 @@ def test_parse_size_to_int(): with pytest.raises(ValueError, match="Could not parse the size value"): parse_size_to_int("1ooKB") # not a float + + +def test_save_torch_state_dict_not_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: + """Save as safetensors without sharding.""" + save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB") + assert (tmp_path / "model.safetensors").is_file() + assert not (tmp_path / "model.safetensors.index.json").is_file() + + +def test_save_torch_state_dict_sharded(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: + """Save as safetensors with sharding.""" + save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size=30) + assert not (tmp_path / "model.safetensors").is_file() + assert (tmp_path / "model.safetensors.index.json").is_file() + assert (tmp_path / "model-00001-of-00002.safetensors").is_file() + assert (tmp_path / "model-00002-of-00002.safetensors").is_file() + + assert json.loads((tmp_path / "model.safetensors.index.json").read_text("utf-8")) == { + "metadata": {"total_size": 40}, + "weight_map": { + "layer_1": "model-00001-of-00002.safetensors", + "layer_2": "model-00001-of-00002.safetensors", + "layer_3": "model-00001-of-00002.safetensors", + "layer_4": "model-00002-of-00002.safetensors", + "layer_5": "model-00002-of-00002.safetensors", + }, + } + + +def test_save_torch_state_dict_unsafe_not_sharded( + tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] +) -> None: + """Save as pickle without sharding.""" + with caplog.at_level("WARNING"): + save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB", safe_serialization=False) + assert "we strongly recommend using safe serialization" in caplog.text + + assert (tmp_path / "pytorch_model.bin").is_file() + assert not (tmp_path / "pytorch_model.bin.index.json").is_file() + + +def test_save_torch_state_dict_unsafe_sharded( + tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"] +) -> None: + """Save as pickle with sharding.""" + # Check logs + with caplog.at_level("WARNING"): + save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size=30, safe_serialization=False) + assert "we strongly recommend using safe serialization" in caplog.text + + assert not (tmp_path / "pytorch_model.bin").is_file() + assert (tmp_path / "pytorch_model.bin.index.json").is_file() + assert (tmp_path / "pytorch_model-00001-of-00002.bin").is_file() + assert (tmp_path / "pytorch_model-00002-of-00002.bin").is_file() + + assert json.loads((tmp_path / "pytorch_model.bin.index.json").read_text("utf-8")) == { + "metadata": {"total_size": 40}, + "weight_map": { + "layer_1": 
"pytorch_model-00001-of-00002.bin", + "layer_2": "pytorch_model-00001-of-00002.bin", + "layer_3": "pytorch_model-00001-of-00002.bin", + "layer_4": "pytorch_model-00002-of-00002.bin", + "layer_5": "pytorch_model-00002-of-00002.bin", + }, + } + + +def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None: + """Custom filename pattern is respected.""" + # Not sharded + save_torch_state_dict(torch_state_dict, tmp_path, filename_pattern="model.variant{suffix}.safetensors") + assert (tmp_path / "model.variant.safetensors").is_file() + + # Sharded + save_torch_state_dict( + torch_state_dict, tmp_path, filename_pattern="model.variant{suffix}.safetensors", max_shard_size=30 + ) + assert (tmp_path / "model.variant.safetensors.index.json").is_file() + assert (tmp_path / "model.variant-00001-of-00002.safetensors").is_file() + assert (tmp_path / "model.variant-00002-of-00002.safetensors").is_file() + + +def test_save_torch_state_dict_delete_existing_files( + tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"] +) -> None: + """Directory is cleaned before saving new files.""" + (tmp_path / "model.safetensors").touch() + (tmp_path / "model.safetensors.index.json").touch() + (tmp_path / "model-00001-of-00003.safetensors").touch() + (tmp_path / "model-00002-of-00003.safetensors").touch() + (tmp_path / "model-00003-of-00003.safetensors").touch() + + (tmp_path / "pytorch_model.bin").touch() + (tmp_path / "pytorch_model.bin.index.json").touch() + (tmp_path / "pytorch_model-00001-of-00003.bin").touch() + (tmp_path / "pytorch_model-00002-of-00003.bin").touch() + (tmp_path / "pytorch_model-00003-of-00003.bin").touch() + + save_torch_state_dict(torch_state_dict, tmp_path) + assert (tmp_path / "model.safetensors").stat().st_size > 0 # new file + + # Previous shards have been deleted + assert not (tmp_path / "model.safetensors.index.json").is_file() # deleted + assert not (tmp_path / "model-00001-of-00003.safetensors").is_file() # deleted + assert not (tmp_path / "model-00002-of-00003.safetensors").is_file() # deleted + assert not (tmp_path / "model-00003-of-00003.safetensors").is_file() # deleted + + # But not previous pickle files (since saving as safetensors) + assert (tmp_path / "pytorch_model.bin").is_file() # not deleted + assert (tmp_path / "pytorch_model.bin.index.json").is_file() + assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file() + assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file() + assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file()