diff --git a/torchrec/schema/api_tests/test_embedding_config_schema.py b/torchrec/schema/api_tests/test_embedding_config_schema.py new file mode 100644 index 000000000..8baefe4e7 --- /dev/null +++ b/torchrec/schema/api_tests/test_embedding_config_schema.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from dataclasses import dataclass, field +from typing import Callable, List, Optional + +import torch +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, + PoolingType, +) + +from torchrec.schema.utils import is_signature_compatible + + +@dataclass +class StableEmbeddingBagConfig: + num_embeddings: int + embedding_dim: int + name: str = "" + data_type: DataType = DataType.FP32 + feature_names: List[str] = field(default_factory=list) + weight_init_max: Optional[float] = None + weight_init_min: Optional[float] = None + num_embeddings_post_pruning: Optional[int] = None + + init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None + # when the position_weighted feature is in this table config, + # enable this flag to support rw_sharding + need_pos: bool = False + pooling: PoolingType = PoolingType.SUM + + +@dataclass +class StableEmbeddingConfig: + num_embeddings: int + embedding_dim: int + name: str = "" + data_type: DataType = DataType.FP32 + feature_names: List[str] = field(default_factory=list) + weight_init_max: Optional[float] = None + weight_init_min: Optional[float] = None + num_embeddings_post_pruning: Optional[int] = None + + init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None + # when the position_weighted feature is in this table config, + # enable this flag to support rw_sharding + need_pos: bool = False + + +class TestEmbeddingConfig(unittest.TestCase): + def test_embedding_bag_config(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagConfig.__init__), + inspect.signature(EmbeddingBagConfig.__init__), + ) + ) + + def test_embedding_config(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingConfig.__init__), + inspect.signature(EmbeddingConfig.__init__), + ) + ) diff --git a/torchrec/schema/api_tests/test_embedding_module_schema.py b/torchrec/schema/api_tests/test_embedding_module_schema.py new file mode 100644 index 000000000..6beeba0d3 --- /dev/null +++ b/torchrec/schema/api_tests/test_embedding_module_schema.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from typing import Dict, List, Optional + +import torch +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) + +from torchrec.schema.utils import is_signature_compatible +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +class StableEmbeddingBagCollectionInterface: + """ + Stable Interface for `EmbeddingBagCollection`. + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + is_weighted: bool = False, + device: Optional[torch.device] = None, + ) -> None: + pass + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + return KeyedTensor( + keys=[], + length_per_key=[], + values=torch.empty(0), + ) + + def embedding_bag_configs( + self, + ) -> List[EmbeddingBagConfig]: + return [] + + def is_weighted(self) -> bool: + return False + + +class StableEmbeddingCollectionInterface: + """ + Stable Interface for `EmbeddingBagCollection`. + """ + + def __init__( + self, + tables: List[EmbeddingConfig], + device: Optional[torch.device] = None, + need_indices: bool = False, + ) -> None: + return + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, JaggedTensor]: + return {} + + def embedding_configs( + self, + ) -> List[EmbeddingConfig]: + return [] + + def need_indices(self) -> bool: + return False + + def embedding_dim(self) -> int: + return 0 + + def embedding_names_by_table(self) -> List[List[str]]: + return [] + + +class TestEmbeddingConfig(unittest.TestCase): + def test_embedding_bag_collection(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagCollectionInterface.__init__), + inspect.signature(EmbeddingBagCollection.__init__), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagCollectionInterface.forward), + inspect.signature(EmbeddingBagCollection.forward), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature( + StableEmbeddingBagCollectionInterface.embedding_bag_configs + ), + inspect.signature(EmbeddingBagCollection.embedding_bag_configs), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagCollectionInterface.is_weighted), + inspect.signature(EmbeddingBagCollection.is_weighted), + ) + ) + + def test_embedding_collection(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.__init__), + inspect.signature(EmbeddingCollection.__init__), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.forward), + inspect.signature(EmbeddingCollection.forward), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.embedding_configs), + inspect.signature(EmbeddingCollection.embedding_configs), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.embedding_dim), + inspect.signature(EmbeddingCollection.embedding_dim), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature( + StableEmbeddingCollectionInterface.embedding_names_by_table + ), + inspect.signature(EmbeddingCollection.embedding_names_by_table), + ) + )