diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 8f8785e2a..f3ad229e6 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -188,6 +188,19 @@ def permute_multi_embedding( return permuted_values +@torch.fx.wrap +def regroup_kts( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + return torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + + @torch.fx.wrap def _fbgemm_permute_pooled_embs( keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index aa426e448..b9dd12d3b 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -18,10 +18,12 @@ from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import ( + _fbgemm_permute_pooled_embs, _regroup_keyed_tensors, KeyedJaggedTensor, KeyedTensor, permute_multi_embedding, + regroup_kts, ) from torchrec.sparse.tests.utils import build_groups, build_kts @@ -213,7 +215,7 @@ def main( ).float() groups = build_groups(kts, n_groups, duplicates=duplicates) bench( - "_regroup_keyed_tenors" + dup, + "[pytorch generic] fallback" + dup, labels, batch_size, n_dense + n_sparse, @@ -224,7 +226,7 @@ def main( profile, ) bench( - "KeyedTensor.regroup" + dup, + "[Prod] KeyedTensor.regroup" + dup, labels, batch_size, n_dense + n_sparse, @@ -235,7 +237,7 @@ def main( profile, ) bench( - "KTRegroupAsDict" + dup, + "[Module] KTRegroupAsDict" + dup, labels, batch_size, n_dense + n_sparse, @@ -248,7 +250,7 @@ def main( profile, ) bench( - "permute_multi_embs" + dup, + "[2 Ops] permute_multi_embs" + dup, labels, batch_size, n_dense + n_sparse, @@ -258,6 +260,29 @@ def main( {"keyed_tensors": kts, "groups": groups}, profile, ) + bench( + "[1 Op] KT_regroup" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + regroup_kts, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + if not duplicates: + bench( + "[Old Prod] permute_pooled_embs" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _fbgemm_permute_pooled_embs, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) if __name__ == "__main__":