From b31daed4111bf7601e84d20daf6fbf6880c94158 Mon Sep 17 00:00:00 2001 From: Joshua Deng Date: Thu, 23 May 2024 14:48:22 -0700 Subject: [PATCH] Add VBE to Dense TBE frontend (#2628) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2628 - add frontend support to Dense TBE module to support VBE - add unit test Differential Revision: D56651380 --- ...t_table_batched_embeddings_ops_training.py | 278 +++++++++++------- .../test/tbe/training/backward_dense_test.py | 77 ++++- 2 files changed, 229 insertions(+), 126 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index a9651e94ac..69163dd557 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -307,6 +307,126 @@ def apply_split_helper( ) +def generate_vbe_metadata( + offsets: Tensor, + batch_size_per_feature_per_rank: Optional[List[List[int]]], + optimizer: OptimType, + pooling_mode: PoolingMode, + feature_dims: Tensor, + device: torch.device, +) -> invokers.lookup_args.VBEMetadata: + """ + Generate VBE metadata based on batch_size_per_feature_per_rank. + Metadata includes: + 1) B_offsets - A tensor that contains batch size offsets for each + feature + 2) output_offsets_feature_rank - A tensor that contains output + offsets for each feature + 3) B_offsets_per_rank_per_feature - A tensor that contains batch + size offsets for each feature + and rank + 4) max_B - The maximum batch size for all features + 5) max_B_feature_rank - The maximum batch size for all ranks and + features + 6) output_size - The output size (number of elements) + """ + if batch_size_per_feature_per_rank is not None: + assert ( + optimizer == OptimType.EXACT_ROWWISE_ADAGRAD + or optimizer == OptimType.EXACT_SGD + or optimizer == OptimType.NONE + ), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD only" + assert ( + pooling_mode != PoolingMode.NONE + ), "Variable batch size TBE support is not enabled for PoolingMode.NONE" + # TODO: Add input check + zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32) + + # Create B offsets + total_batch_size_per_feature = torch.tensor( + batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu" + ).sum(dim=1) + + max_B = total_batch_size_per_feature.max().item() + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + torch._check_is_size(max_B) + torch._check(max_B < offsets.numel()) + + Bs = torch.concat([zero_tensor, total_batch_size_per_feature]) + B_offsets = Bs.cumsum(dim=0).to(torch.int) + + # Create output offsets + B_feature_rank = torch.tensor( + batch_size_per_feature_per_rank, + device="cpu", + dtype=torch.int64, + ) + max_B_feature_rank = B_feature_rank.max().item() + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + torch._check_is_size(max_B_feature_rank) + torch._check(max_B_feature_rank <= offsets.size(0)) + # D->H only once + feature_dims = feature_dims.cpu() + output_sizes_feature_rank = B_feature_rank.transpose(0, 1) * feature_dims.view( + 1, -1 + ) + output_offsets_feature_rank = torch.concat( + [ + zero_tensor.to(torch.int64), + output_sizes_feature_rank.flatten().cumsum(dim=0), + ] + ) + output_size = output_offsets_feature_rank[-1].item() + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + torch._check_is_size(output_size) + + # TODO: Support INT8 output + # B_offsets_rank_per_feature is for rank and (b, t) mapping + B_offsets_rank_per_feature = ( + torch.tensor( + [ + [0] + batch_size_per_feature + for batch_size_per_feature in batch_size_per_feature_per_rank + ], + device="cpu", + dtype=torch.int32, + ) + .cumsum(dim=1) + .to(torch.int) + ) + + B_offsets = B_offsets.to(device, non_blocking=True) + output_offsets_feature_rank = output_offsets_feature_rank.to( + device, non_blocking=True + ) + B_offsets_rank_per_feature = B_offsets_rank_per_feature.to( + device, non_blocking=True + ) + + # TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank + vbe_metadata = invokers.lookup_args.VBEMetadata( + B_offsets=B_offsets, + output_offsets_feature_rank=output_offsets_feature_rank, + B_offsets_rank_per_feature=B_offsets_rank_per_feature, + # pyre-ignore + max_B=max_B, + # pyre-ignore + max_B_feature_rank=max_B_feature_rank, + # pyre-ignore + output_size=output_size, + ) + else: + vbe_metadata = invokers.lookup_args.VBEMetadata( + B_offsets=None, + output_offsets_feature_rank=None, + B_offsets_rank_per_feature=None, + max_B=-1, + max_B_feature_rank=-1, + output_size=-1, + ) + return vbe_metadata + + # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized. # pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized. class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): @@ -1092,120 +1212,20 @@ def _report_io_size_count(self, event: str, data: Tensor) -> Tensor: ) return data + @torch.jit.ignore def _generate_vbe_metadata( self, offsets: Tensor, batch_size_per_feature_per_rank: Optional[List[List[int]]], ) -> invokers.lookup_args.VBEMetadata: - """ - Generate VBE metadata based on batch_size_per_feature_per_rank. - Metadata includes: - 1) B_offsets - A tensor that contains batch size offsets for each - feature - 2) output_offsets_feature_rank - A tensor that contains output - offsets for each feature - 3) B_offsets_per_rank_per_feature - A tensor that contains batch - size offsets for each feature - and rank - 4) max_B - The maximum batch size for all features - 5) max_B_feature_rank - The maximum batch size for all ranks and - features - 6) output_size - The output size (number of elements) - """ - if batch_size_per_feature_per_rank is not None: - assert ( - self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD - or self.optimizer == OptimType.EXACT_SGD - ), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD only" - assert ( - self.pooling_mode != PoolingMode.NONE.value - ), "Variable batch size TBE support is not enabled for PoolingMode.NONE" - # TODO: Add input check - zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32) - - # Create B offsets - total_batch_size_per_feature = torch.tensor( - batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu" - ).sum(dim=1) - - max_B = total_batch_size_per_feature.max().item() - if not torch.jit.is_scripting() and is_torchdynamo_compiling(): - torch._check_is_size(max_B) - torch._check(max_B < offsets.numel()) - - Bs = torch.concat([zero_tensor, total_batch_size_per_feature]) - B_offsets = Bs.cumsum(dim=0).to(torch.int) - - # Create output offsets - B_feature_rank = torch.tensor( - batch_size_per_feature_per_rank, - device="cpu", - dtype=torch.int64, - ) - max_B_feature_rank = B_feature_rank.max().item() - if not torch.jit.is_scripting() and is_torchdynamo_compiling(): - torch._check_is_size(max_B_feature_rank) - torch._check(max_B_feature_rank <= offsets.size(0)) - # D->H only once - self.feature_dims = self.feature_dims.cpu() - output_sizes_feature_rank = B_feature_rank.transpose( - 0, 1 - ) * self.feature_dims.view(1, -1) - output_offsets_feature_rank = torch.concat( - [ - zero_tensor.to(torch.int64), - output_sizes_feature_rank.flatten().cumsum(dim=0), - ] - ) - output_size = output_offsets_feature_rank[-1].item() - if not torch.jit.is_scripting() and is_torchdynamo_compiling(): - torch._check_is_size(output_size) - - # TODO: Support INT8 output - # B_offsets_rank_per_feature is for rank and (b, t) mapping - B_offsets_rank_per_feature = ( - torch.tensor( - [ - [0] + batch_size_per_feature - for batch_size_per_feature in batch_size_per_feature_per_rank - ], - device="cpu", - dtype=torch.int32, - ) - .cumsum(dim=1) - .to(torch.int) - ) - - B_offsets = B_offsets.to(self.current_device, non_blocking=True) - output_offsets_feature_rank = output_offsets_feature_rank.to( - self.current_device, non_blocking=True - ) - B_offsets_rank_per_feature = B_offsets_rank_per_feature.to( - self.current_device, non_blocking=True - ) - - # TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank - vbe_metadata = invokers.lookup_args.VBEMetadata( - B_offsets=B_offsets, - output_offsets_feature_rank=output_offsets_feature_rank, - B_offsets_rank_per_feature=B_offsets_rank_per_feature, - # pyre-ignore - max_B=max_B, - # pyre-ignore - max_B_feature_rank=max_B_feature_rank, - # pyre-ignore - output_size=output_size, - ) - else: - vbe_metadata = invokers.lookup_args.VBEMetadata( - B_offsets=None, - output_offsets_feature_rank=None, - B_offsets_rank_per_feature=None, - max_B=-1, - max_B_feature_rank=-1, - output_size=-1, - ) - return vbe_metadata + return generate_vbe_metadata( + offsets, + batch_size_per_feature_per_rank, + self.optimizer, + self.pooling_mode, + self.feature_dims, + self.current_device, + ) def forward( # noqa: C901 self, @@ -1571,8 +1591,7 @@ def prefetch( self.prefetch_stream != forward_stream ), "prefetch_stream and forward_stream should not be the same stream" vbe_metadata = self._generate_vbe_metadata( - offsets, - batch_size_per_feature_per_rank, + offsets, batch_size_per_feature_per_rank ) self._prefetch( indices, @@ -2473,8 +2492,9 @@ def __init__( ) T = len(feature_table_map) assert T_ <= T - D_offsets = [dims[t] for t in feature_table_map] - D_offsets = [0] + list(accumulate(D_offsets)) + + feature_dims = [dims[t] for t in feature_table_map] + D_offsets = [0] + list(accumulate(feature_dims)) self.total_D = D_offsets[-1] self.max_D = max(dims) self.register_buffer( @@ -2482,6 +2502,11 @@ def __init__( torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) assert self.D_offsets.numel() == T + 1 + # Required for VBE + self.register_buffer( + "feature_dims", + torch.tensor(feature_dims, device="cpu", dtype=torch.int64), + ) hash_size_cumsum = [0] + list(accumulate(rows)) if hash_size_cumsum[-1] == 0: @@ -2536,13 +2561,34 @@ def __init__( ), ) + @torch.jit.ignore + def _generate_vbe_metadata( + self, + offsets: Tensor, + batch_size_per_feature_per_rank: Optional[List[List[int]]], + ) -> invokers.lookup_args.VBEMetadata: + return generate_vbe_metadata( + offsets, + batch_size_per_feature_per_rank, + OptimType.NONE, + self.pooling_mode, + self.feature_dims, + self.current_device, + ) + def forward( self, indices: Tensor, offsets: Tensor, per_sample_weights: Optional[Tensor] = None, feature_requires_grad: Optional[Tensor] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> Tensor: + # Generate VBE metadata + vbe_metadata = self._generate_vbe_metadata( + offsets, batch_size_per_feature_per_rank + ) + (indices, offsets) = indices.long(), offsets.long() # Force casting per_sample_weights to float if per_sample_weights is not None: @@ -2562,6 +2608,12 @@ def forward( indice_weights=per_sample_weights, feature_requires_grad=feature_requires_grad, output_dtype=self.output_dtype, + B_offsets=vbe_metadata.B_offsets, + vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank, + vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature, + max_B=vbe_metadata.max_B, + max_B_feature_rank=vbe_metadata.max_B_feature_rank, + vbe_output_size=vbe_metadata.output_size, ) @torch.jit.export diff --git a/fbgemm_gpu/test/tbe/training/backward_dense_test.py b/fbgemm_gpu/test/tbe/training/backward_dense_test.py index 484d9a37bf..eeafeb925d 100644 --- a/fbgemm_gpu/test/tbe/training/backward_dense_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_dense_test.py @@ -14,6 +14,10 @@ import hypothesis.strategies as st import numpy as np import torch +from deeplearning.fbgemm.fbgemm_gpu.test.tbe.common import ( + format_ref_tensors_in_mixed_B_layout, + gen_mixed_B_batch_sizes, +) from fbgemm_gpu.split_embedding_configs import SparseType from fbgemm_gpu.split_embedding_utils import ( b_indices, @@ -51,6 +55,7 @@ class BackwardDenseTest(unittest.TestCase): weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), weighted=st.booleans(), mixed=st.booleans(), + mixed_B=st.booleans(), long_segments=st.booleans(), pooling_mode=st.sampled_from( [ @@ -80,6 +85,7 @@ def test_backward_dense( # noqa C901 weights_precision: SparseType, weighted: bool, mixed: bool, + mixed_B: bool, long_segments: bool, pooling_mode: PoolingMode, use_cpu: bool, @@ -94,6 +100,7 @@ def test_backward_dense( # noqa C901 assume(not use_cpu or pooling_mode != PoolingMode.NONE) assume(not mixed or pooling_mode != PoolingMode.NONE) assume(not weighted or pooling_mode != PoolingMode.NONE) + assume(not mixed_B or (not use_cpu and pooling_mode != PoolingMode.NONE)) emb_op = DenseTableBatchedEmbeddingBagsCodegen if pooling_mode == PoolingMode.SUM: @@ -139,22 +146,30 @@ def test_backward_dense( # noqa C901 if weights_precision == SparseType.FP16: bs = [b.half() for b in bs] + feature_table_map = list(range(T)) + num_features = len(feature_table_map) + if not mixed_B: + Bs = [B] * num_features + Bs_rank_feature = [[0]] + else: + Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, num_features) + xs = [ to_device( torch.from_numpy( - np.random.choice(range(e), size=(B, L), replace=True).astype( + np.random.choice(range(Es[t]), size=(b, L), replace=True).astype( np.int64 ) ), use_cpu, ) - for e in Es + for t, b in zip(feature_table_map, Bs) ] if long_segments and L > 0 and weights_precision != SparseType.FP16: for x in xs: x[:, 0] = 0 - xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)] + xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] if weights_precision == SparseType.FP16: xws = [xw.half() for xw in xws] @@ -198,23 +213,39 @@ def test_backward_dense( # noqa C901 ) if do_pooling: # NOTE: test TorchScript-compatible! - cc = torch.jit.script(cc) + torch.jit.script(cc) for t in range(T): cc.split_embedding_weights()[t].data.copy_(bs[t].weight) - x = torch.cat([x.view(1, B, L) for x in xs], dim=0) - xw = torch.cat([xw.view(1, B, L) for xw in xws], dim=0) + x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) + xw = torch.cat([xw.contiguous().flatten() for xw in xws], dim=0) + + (indices, offsets) = get_table_batched_offsets_from_dense( + x, L, sum(Bs), use_cpu=use_cpu + ) + batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None - (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=use_cpu) fc2 = ( - cc(indices, offsets) + cc( + indices, + offsets, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) if not weighted - else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)) + else cc( + indices, + offsets, + to_device(xw.contiguous().view(-1), use_cpu), + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) ) if do_pooling: - f = torch.cat([f.view(B, -1) for f in fs], dim=1) + if mixed_B: + f = format_ref_tensors_in_mixed_B_layout(fs, Bs_rank_feature) + else: + f = torch.cat([f.view(B, -1) for f in fs], dim=1) else: f = torch.cat(fs, dim=0).view(-1, D) @@ -231,7 +262,10 @@ def test_backward_dense( # noqa C901 rtol=tol, ) if do_pooling: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1) + if mixed_B: + goc = format_ref_tensors_in_mixed_B_layout(gos, Bs_rank_feature) + else: + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) else: goc = torch.cat(gos, dim=0) fc2.backward(goc) @@ -256,7 +290,12 @@ def test_backward_dense( # noqa C901 offsets.requires_grad = False for param in cc.parameters(): param.requires_grad = False - y = cc(indices, offsets, per_sample_weights) + y = cc( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) y.sum().backward() # pyre-fixme[16]: `Optional` has no attribute `clone`. indice_weight_grad_all = per_sample_weights.grad.clone().cpu() @@ -272,10 +311,12 @@ def test_backward_dense( # noqa C901 offsets, per_sample_weights, feature_requires_grad=feature_requires_grad, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, ) y.sum().backward() indice_weight_grad_mask = per_sample_weights.grad.clone().cpu() for t in range(T_): + B = Bs[t] if feature_requires_grad[t]: torch.testing.assert_close( indice_weight_grad_mask.view(T_, B, L)[t], @@ -296,7 +337,17 @@ def test_backward_dense( # noqa C901 for param in cc.parameters(): param.requires_grad = False gradcheck( - cc, (indices, offsets, per_sample_weights), eps=1e-2, atol=1e-3, rtol=1e-3 + cc, + ( + indices, + offsets, + per_sample_weights, + None, + batch_size_per_feature_per_rank, + ), + eps=1e-2, + atol=1e-3, + rtol=1e-3, )