Skip to content

Commit

Permalink
Serialization: support saving torch state dict to disk (#2314)
Browse files Browse the repository at this point in the history
* Save Pytorch state dict

* fix tests

* fix test

* Add tests

* docstring

* docs

* Skip if torch not installed

* Respect custom filename pattern

* Update docs/source/en/package_reference/serialization.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
Wauplin and amyeroberts authored Jun 7, 2024
1 parent e43874a commit 122a057
Show file tree
Hide file tree
Showing 10 changed files with 316 additions and 118 deletions.
12 changes: 7 additions & 5 deletions docs/source/en/package_reference/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ rendered properly in your Markdown viewer.

# Serialization

`huggingface_hub` contains helpers to help ML libraries to serialize models weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub.
`huggingface_hub` contains helpers to help ML libraries serialize models weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub.

## Split state dict into shards
## Save torch state dict

The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process and save everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`].

At the moment, this module contains a single helper that takes a state dictionary (e.g. a mapping between layer names and related tensors) and split it into several shards, while creating a proper index in the process. This helper is available for `torch`, `tensorflow` and `numpy` tensors and is designed to be easily extended to any other ML frameworks.
[[autodoc]] huggingface_hub.save_torch_state_dict

### split_numpy_state_dict_into_shards
## Split state dict into shards

[[autodoc]] huggingface_hub.split_numpy_state_dict_into_shards
The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML frameworks.

### split_tf_state_dict_into_shards

Expand Down
4 changes: 0 additions & 4 deletions docs/source/ko/package_reference/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ rendered properly in your Markdown viewer.

현재 이 모듈은 상태 딕셔너리(예: 레이어 이름과 관련 텐서 간의 매핑)를 받아 여러 샤드로 나누고, 이 과정에서 적절한 인덱스를 생성하는 단일 헬퍼를 포함하고 있습니다. 이 헬퍼는 `torch`, `tensorflow`, `numpy` 텐서에 사용 가능하며, 다른 ML 프레임워크로 쉽게 확장될 수 있도록 설계되었습니다.

### split_numpy_state_dict_into_shards[[huggingface_hub.split_numpy_state_dict_into_shards]]

[[autodoc]] huggingface_hub.split_numpy_state_dict_into_shards

### split_tf_state_dict_into_shards[[huggingface_hub.split_tf_state_dict_into_shards]]

[[autodoc]] huggingface_hub.split_tf_state_dict_into_shards
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@
"serialization": [
"StateDictSplit",
"get_torch_storage_id",
"split_numpy_state_dict_into_shards",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",
"split_tf_state_dict_into_shards",
"split_torch_state_dict_into_shards",
Expand Down Expand Up @@ -904,7 +904,7 @@ def __dir__():
from .serialization import (
StateDictSplit, # noqa: F401
get_torch_storage_id, # noqa: F401
split_numpy_state_dict_into_shards, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
split_tf_state_dict_into_shards, # noqa: F401
split_torch_state_dict_into_shards, # noqa: F401
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def _as_int(value: Optional[str]) -> Optional[int]:
DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024
HF_TRANSFER_CONCURRENCY = 100

# Constants for serialization

PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin" # Unsafe pickle: use safetensors instead
SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors"
TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5"

# Constants for safetensors repos

SAFETENSORS_SINGLE_FILE = "model.safetensors"
Expand Down
3 changes: 1 addition & 2 deletions src/huggingface_hub/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@
"""Contains helpers to serialize tensors."""

from ._base import StateDictSplit, split_state_dict_into_shards_factory
from ._numpy import split_numpy_state_dict_into_shards
from ._tensorflow import split_tf_state_dict_into_shards
from ._torch import get_torch_storage_id, split_torch_state_dict_into_shards
from ._torch import get_torch_storage_id, save_torch_state_dict, split_torch_state_dict_into_shards
21 changes: 9 additions & 12 deletions src/huggingface_hub/serialization/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@
TensorSizeFn_T = Callable[[TensorT], int]
StorageIDFn_T = Callable[[TensorT], Optional[Any]]

MAX_SHARD_SIZE = 5_000_000_000 # 5GB
FILENAME_PATTERN = "model{suffix}.safetensors"
MAX_SHARD_SIZE = "5GB"
SIZE_UNITS = {
"TB": 10**12,
"GB": 10**9,
"MB": 10**6,
"KB": 10**3,
}


logger = logging.get_logger(__file__)

Expand All @@ -44,8 +50,8 @@ def split_state_dict_into_shards_factory(
state_dict: Dict[str, TensorT],
*,
get_tensor_size: TensorSizeFn_T,
filename_pattern: str,
get_storage_id: StorageIDFn_T = lambda tensor: None,
filename_pattern: str = FILENAME_PATTERN,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Expand Down Expand Up @@ -75,7 +81,6 @@ def split_state_dict_into_shards_factory(
filename_pattern (`str`, *optional*):
The pattern to generate the files names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
Defaults to `"model{suffix}.safetensors"`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
Expand Down Expand Up @@ -172,14 +177,6 @@ def split_state_dict_into_shards_factory(
)


SIZE_UNITS = {
"TB": 10**12,
"GB": 10**9,
"MB": 10**6,
"KB": 10**3,
}


def parse_size_to_int(size_as_str: str) -> int:
"""
Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
Expand Down
68 changes: 0 additions & 68 deletions src/huggingface_hub/serialization/_numpy.py

This file was deleted.

3 changes: 2 additions & 1 deletion src/huggingface_hub/serialization/_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
from typing import TYPE_CHECKING, Dict, Union

from .. import constants
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory


Expand All @@ -27,7 +28,7 @@
def split_tf_state_dict_into_shards(
state_dict: Dict[str, "tf.Tensor"],
*,
filename_pattern: str = "tf_model{suffix}.h5",
filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Expand Down
140 changes: 137 additions & 3 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,27 @@
"""Contains pytorch-specific helpers."""

import importlib
import json
import os
import re
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, Tuple, Union
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory


logger = logging.get_logger(__file__)

if TYPE_CHECKING:
import torch


def split_torch_state_dict_into_shards(
state_dict: Dict[str, "torch.Tensor"],
*,
filename_pattern: str = FILENAME_PATTERN,
filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Expand All @@ -38,6 +45,14 @@ def split_torch_state_dict_into_shards(
have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
[6+2+2GB], [6+2GB], [6GB].
<Tip>
To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses
`split_torch_state_dict_into_shards` under the hood.
</Tip>
<Tip warning={true}>
If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
Expand Down Expand Up @@ -92,6 +107,125 @@ def split_torch_state_dict_into_shards(
)


def save_torch_state_dict(
state_dict: Dict[str, "torch.Tensor"],
save_directory: Union[str, Path],
*,
safe_serialization: bool = True,
filename_pattern: Optional[str] = None,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> None:
"""
Save a model state dictionary to the disk.
The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
[`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
safetensors (the default). Otherwise, the shards are saved as pickle.
Before saving the model, the `save_directory` is cleaned from any previous shard files.
<Tip warning={true}>
If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
size greater than `max_shard_size`.
</Tip>
Args:
state_dict (`Dict[str, torch.Tensor]`):
The state dictionary to save.
save_directory (`str` or `Path`):
The directory in which the model will be saved.
safe_serialization (`bool`, *optional*):
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
filename_pattern (`str`, *optional*):
The pattern to generate the files names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
parameter.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
Example:
```py
>>> from huggingface_hub import save_torch_state_dict
>>> model = ... # A PyTorch model
# Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
>>> state_dict = model_to_save.state_dict()
>>> save_torch_state_dict(state_dict, "path/to/folder")
```
"""
save_directory = str(save_directory)

if filename_pattern is None:
filename_pattern = (
constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
if safe_serialization
else constants.PYTORCH_WEIGHTS_FILE_PATTERN
)

# Imports correct library
if safe_serialization:
try:
from safetensors.torch import save_file as save_file_fn
except ImportError as e:
raise ImportError(
"Please install `safetensors` to use safe serialization. "
"You can install it with `pip install safetensors`."
) from e

else:
from torch import save as save_file_fn # type: ignore[assignment]

logger.warning(
"You are using unsafe serialization. Due to security reasons, it is recommended not to load "
"pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
"using safe serialization by installing `safetensors` with `pip install safetensors`."
)

# Split dict
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)

# Clean the folder from previous save
existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
for filename in os.listdir(save_directory):
if existing_files_regex.match(filename):
try:
logger.debug(f"Removing existing file '{filename}' from folder.")
os.remove(os.path.join(save_directory, filename))
except Exception as e:
logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")

# Save each shard
safe_file_kwargs = {"metadata": {"format": "pt"}} if safe_serialization else {}
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor] for tensor in tensors}
save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)
logger.debug(f"Shard saved to {filename}")

# Save the index (if any)
if state_dict_split.is_sharded:
index_path = filename_pattern.format(suffix="") + ".index.json"
index = {"metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename}
with open(os.path.join(save_directory, index_path), "w") as f:
json.dump(index, f, indent=2)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
f"You can find where each parameters has been saved in the index located at {index_path}."
)

logger.info(f"Model weights successfully saved to {save_directory}!")


def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
"""
Return unique identifier to a tensor storage.
Expand Down
Loading

0 comments on commit 122a057

Please sign in to comment.