diff --git a/docs/source/en/package_reference/serialization.md b/docs/source/en/package_reference/serialization.md index 545b486cee..f63d4e343a 100644 --- a/docs/source/en/package_reference/serialization.md +++ b/docs/source/en/package_reference/serialization.md @@ -27,3 +27,9 @@ At the moment, this module contains a single helper that takes a state dictionar This is the underlying factory from which each framework-specific helper is derived. In practice, you are not expected to use this factory directly except if you need to adapt it to a framework that is not yet supported. If that is the case, please let us know by [opening a new issue](https://github.com/huggingface/huggingface_hub/issues/new) on the `huggingface_hub` repo. [[autodoc]] huggingface_hub.split_state_dict_into_shards_factory + +## Helpers + +### get_torch_storage_id + +[[autodoc]] huggingface_hub.get_torch_storage_id \ No newline at end of file diff --git a/docs/source/ko/package_reference/serialization.md b/docs/source/ko/package_reference/serialization.md index c3bc14f04a..a3b515c1e7 100644 --- a/docs/source/ko/package_reference/serialization.md +++ b/docs/source/ko/package_reference/serialization.md @@ -26,4 +26,10 @@ rendered properly in your Markdown viewer. 이것은 각 프레임워크별 헬퍼가 파생되는 기본 틀입니다. 실제로는 아직 지원되지 않는 프레임워크에 맞게 조정할 필요가 있는 경우가 아니면 이 틀을 직접 사용할 것으로 예상되지 않습니다. 그런 경우가 있다면, `huggingface_hub` 리포지토리에 [새로운 이슈를 개설](https://github.com/huggingface/huggingface_hub/issues/new) 하여 알려주세요. -[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory \ No newline at end of file +[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory + +## 도우미 + +### get_torch_storage_id[[huggingface_hub.get_torch_storage_id]] + +[[autodoc]] huggingface_hub.get_torch_storage_id \ No newline at end of file diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 1c9dc59559..fdbd33d6a9 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -423,6 +423,7 @@ ], "serialization": [ "StateDictSplit", + "get_torch_storage_id", "split_numpy_state_dict_into_shards", "split_state_dict_into_shards_factory", "split_tf_state_dict_into_shards", @@ -902,6 +903,7 @@ def __dir__(): from .repository import Repository # noqa: F401 from .serialization import ( StateDictSplit, # noqa: F401 + get_torch_storage_id, # noqa: F401 split_numpy_state_dict_into_shards, # noqa: F401 split_state_dict_into_shards_factory, # noqa: F401 split_tf_state_dict_into_shards, # noqa: F401 diff --git a/src/huggingface_hub/serialization/__init__.py b/src/huggingface_hub/serialization/__init__.py index 0bb6c2d0a1..2d3fe3aa37 100644 --- a/src/huggingface_hub/serialization/__init__.py +++ b/src/huggingface_hub/serialization/__init__.py @@ -17,4 +17,4 @@ from ._base import StateDictSplit, split_state_dict_into_shards_factory from ._numpy import split_numpy_state_dict_into_shards from ._tensorflow import split_tf_state_dict_into_shards -from ._torch import split_torch_state_dict_into_shards +from ._torch import get_torch_storage_id, split_torch_state_dict_into_shards diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py index 7ccce3c281..349e7312e4 100644 --- a/src/huggingface_hub/serialization/_torch.py +++ b/src/huggingface_hub/serialization/_torch.py @@ -88,11 +88,11 @@ def split_torch_state_dict_into_shards( max_shard_size=max_shard_size, filename_pattern=filename_pattern, get_tensor_size=get_tensor_size, - get_storage_id=get_storage_id, + get_storage_id=get_torch_storage_id, ) -def get_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]: +def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]: """ Return unique identifier to a tensor storage.