Skip to content

Commit

Permalink
Make get_torch_storage_id public (#2304)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored May 31, 2024
1 parent 5b360fd commit 04d18e6
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 4 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/package_reference/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ At the moment, this module contains a single helper that takes a state dictionar
This is the underlying factory from which each framework-specific helper is derived. In practice, you are not expected to use this factory directly except if you need to adapt it to a framework that is not yet supported. If that is the case, please let us know by [opening a new issue](https://github.com/huggingface/huggingface_hub/issues/new) on the `huggingface_hub` repo.

[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory

## Helpers

### get_torch_storage_id

[[autodoc]] huggingface_hub.get_torch_storage_id
8 changes: 7 additions & 1 deletion docs/source/ko/package_reference/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,10 @@ rendered properly in your Markdown viewer.

이것은 각 프레임워크별 헬퍼가 파생되는 기본 틀입니다. 실제로는 아직 지원되지 않는 프레임워크에 맞게 조정할 필요가 있는 경우가 아니면 이 틀을 직접 사용할 것으로 예상되지 않습니다. 그런 경우가 있다면, `huggingface_hub` 리포지토리에 [새로운 이슈를 개설](https://github.com/huggingface/huggingface_hub/issues/new) 하여 알려주세요.

[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory
[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory

## 도우미

### get_torch_storage_id[[huggingface_hub.get_torch_storage_id]]

[[autodoc]] huggingface_hub.get_torch_storage_id
2 changes: 2 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@
],
"serialization": [
"StateDictSplit",
"get_torch_storage_id",
"split_numpy_state_dict_into_shards",
"split_state_dict_into_shards_factory",
"split_tf_state_dict_into_shards",
Expand Down Expand Up @@ -902,6 +903,7 @@ def __dir__():
from .repository import Repository # noqa: F401
from .serialization import (
StateDictSplit, # noqa: F401
get_torch_storage_id, # noqa: F401
split_numpy_state_dict_into_shards, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
split_tf_state_dict_into_shards, # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from ._base import StateDictSplit, split_state_dict_into_shards_factory
from ._numpy import split_numpy_state_dict_into_shards
from ._tensorflow import split_tf_state_dict_into_shards
from ._torch import split_torch_state_dict_into_shards
from ._torch import get_torch_storage_id, split_torch_state_dict_into_shards
4 changes: 2 additions & 2 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def split_torch_state_dict_into_shards(
max_shard_size=max_shard_size,
filename_pattern=filename_pattern,
get_tensor_size=get_tensor_size,
get_storage_id=get_storage_id,
get_storage_id=get_torch_storage_id,
)


def get_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
"""
Return unique identifier to a tensor storage.
Expand Down

0 comments on commit 04d18e6

Please sign in to comment.