Commit 259c8f3: fix docstrings
hanouticelina committed Dec 2, 2024 (1 parent: 5075b8f)
Showing 2 changed files with 13 additions and 15 deletions.
24 changes: 12 additions & 12 deletions src/huggingface_hub/serialization/_torch.py
@@ -371,16 +371,18 @@ def load_torch_model(
model (`torch.nn.Module`):
The model in which to load the checkpoint.
checkpoint_path (`str` or `os.PathLike`):
-   Path to either the checkpoint file or directory containing sharded checkpoints.
+   Path to either the checkpoint file or directory containing the checkpoint(s).
strict (`bool`, *optional*, defaults to `False`):
Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
safe (`bool`, *optional*, defaults to `True`):
-   If True, use safetensors for loading when possible. Otherwise, use torch.load.
+   If both safetensors and PyTorch save files are present in checkpoint and `safe` is True, the
+   safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
weights_only (`bool`, *optional*, defaults to `False`):
If True, only loads the model weights without optimizer states and other metadata.
Only supported in PyTorch >= 1.13.
map_location (`str` or `torch.device`, *optional*):
-   A string or torch.device specifying how to remap storage locations.
+   A `torch.device` object, string or a dict specifying how to remap storage locations. It
+   indicates the location where all tensors should be loaded.
Returns:
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
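For orientation, here is a minimal usage sketch of `load_torch_model` as the updated docstring describes it. It assumes a huggingface_hub version that exports the function at the package root; the checkpoint path is hypothetical:

```python
import torch

from huggingface_hub import load_torch_model  # assumed top-level export

model = torch.nn.Linear(4, 2)  # stand-in for a real model
result = load_torch_model(
    model,
    "checkpoint_dir",    # hypothetical: a single file or a directory of shards
    strict=False,        # tolerate missing/unexpected keys
    safe=True,           # prefer *.safetensors over *.bin when both are present
    map_location="cpu",  # load every tensor onto CPU
)
print(result.missing_keys, result.unexpected_keys)
```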
@@ -424,7 +426,7 @@ def load_torch_model(
weights_only=weights_only,
)

-   # Look for single model file (safetensors or pytorch)
+   # Look for single model file
model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
if len(model_files) == 1:
state_dict = load_state_dict_from_file(
@@ -454,9 +456,7 @@ def load_sharded_checkpoint(
"""
Loads a sharded checkpoint into a model. This is the same as
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
-   but for a sharded checkpoint.
-   Each shard is loaded one by one in RAM and deleted after being loaded into the model.
+   but for a sharded checkpoint. Each shard is loaded one by one and deleted after being loaded into the model.
Args:
model (`torch.nn.Module`):
@@ -465,8 +465,8 @@
A path to a folder containing the sharded checkpoint.
strict (`bool`, *optional*, defaults to `False`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
-   safe (`bool`, *optional*, defaults to `False`):
-       If both safetensors and PyTorch save files are present in checkpoint and `safe_deserialization` is True, the
+   safe (`bool`, *optional*, defaults to `True`):
+       If both safetensors and PyTorch save files are present in checkpoint and `safe` is True, the
safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
weights_only (`bool`, *optional*, defaults to `False`):
If True, only loads the model weights without optimizer states and other metadata.
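A hedged sketch of calling `load_sharded_checkpoint`. The import path mirrors the private module this diff edits (`src/huggingface_hub/serialization/_torch.py`), so it may not be part of the public API, and the folder name is hypothetical:

```python
import torch

from huggingface_hub.serialization._torch import load_sharded_checkpoint  # private module edited above

model = torch.nn.Linear(4, 2)  # stand-in module
# Shards are loaded one at a time and freed once applied to the model,
# keeping peak memory near the size of a single shard.
result = load_sharded_checkpoint(model, "sharded_dir", strict=False, safe=True)
```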
@@ -528,11 +528,11 @@ def load_state_dict_from_file(
mmap: bool = False,
) -> Union[Dict[str, "torch.Tensor"], Any]:
"""
-   Loads a checkpoint file, handling both safetensors and PyTorch checkpoint formats.
+   Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
Args:
checkpoint_file (`str` or `os.PathLike`):
-   Path to the checkpoint file to load. Can be either a safetensors or PyTorch checkpoint.
+   Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
map_location (`str` or `torch.device`, *optional*):
A `torch.device` object, string or a dict specifying how to remap storage locations. It
indicates the location where all tensors should be loaded.
@@ -546,7 +546,7 @@
Returns:
`Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
- For safetensors files: always returns a dictionary mapping parameter names to tensors.
-     - For PyTorch files: returns any Python object that was pickled (commonly a state dict, but could be
+     - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
an entire model, optimizer state, or any other Python object).
Raises:
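To make the two formats concrete, a small sketch of `load_state_dict_from_file` on each (assuming the function is importable from huggingface_hub; the file names are hypothetical):

```python
from huggingface_hub import load_state_dict_from_file  # assumed top-level export

# Safetensors file: always returns a dict mapping parameter names to tensors.
state_dict = load_state_dict_from_file("model.safetensors", map_location="cpu")

# Pickle (.bin) file: returns whatever object was pickled; weights_only=True
# (PyTorch >= 1.13) restricts unpickling to tensor data only.
checkpoint = load_state_dict_from_file(
    "pytorch_model.bin",
    map_location="cpu",
    weights_only=True,
)
```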
4 changes: 1 addition & 3 deletions tests/test_serialization.py
@@ -83,7 +83,6 @@ class DummyModel(torch.nn.Module):

def __init__(self):
super().__init__()
-   # Register parameters with float tensors (not integers)
self.register_parameter("layer_1", torch.nn.Parameter(torch.tensor([4.0])))
self.register_parameter("layer_2", torch.nn.Parameter(torch.tensor([10.0])))
self.register_parameter("layer_3", torch.nn.Parameter(torch.tensor([30.0])))
@@ -556,7 +555,6 @@ def test_load_sharded_state_dict(
"""Test saving and loading a sharded state dict."""
import torch

-   # Save with small shard size to force sharding
save_torch_state_dict(
torch_state_dict,
save_directory=tmp_path,
@@ -643,7 +641,7 @@ def test_load_sharded_model_strict_mode(tmp_path, torch_state_dict, dummy_model,
save_torch_state_dict(
modified_dict,
save_directory=tmp_path,
-   max_shard_size=30,  # Small size to force sharding
+   max_shard_size=30,
)

if strict:
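The strict-mode test above relies on a tiny `max_shard_size` to split the state dict across several shard files; here is a self-contained sketch of the same round trip (the output directory is hypothetical):

```python
import torch

from huggingface_hub import load_torch_model, save_torch_state_dict

# Keys chosen to match torch.nn.Linear(4, 2), so strict loading succeeds.
state_dict = {"weight": torch.randn(2, 4), "bias": torch.randn(2)}

# max_shard_size is in bytes when given as an int; 30 forces the dict to be
# split into multiple shard files plus an index file.
save_torch_state_dict(state_dict, save_directory="out_dir", max_shard_size=30)

model = torch.nn.Linear(4, 2)
load_torch_model(model, "out_dir", strict=True)  # raises if any key mismatches
```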
