diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 7e71acb19..84ac0e1a7 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1201,62 +1201,59 @@ def _maybe_compute_kjt_to_jt_dict(
     if not length_per_key:
         return {}
 
-    if jt_dict is None:
-        _jt_dict: Dict[str, JaggedTensor] = {}
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            cat_size = 0
-            total_size = values.size(0)
-            for i in length_per_key:
-                cat_size += i
-                torch._check(cat_size <= total_size)
-            torch._check(cat_size == total_size)
-        values_list = torch.split(values, length_per_key)
-        if variable_stride_per_key:
-            split_lengths = torch.split(lengths, stride_per_key)
-            split_offsets = [
-                torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-                for lengths in split_lengths
-            ]
-        else:
-            split_lengths = torch.unbind(
-                (
-                    lengths.view(-1, stride)
-                    if pt2_guard_size_oblivious(lengths.numel() != 0)
-                    else lengths
-                ),
-                dim=0,
+    if jt_dict is not None:
+        return jt_dict
+
+    _jt_dict: Dict[str, JaggedTensor] = {}
+    if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+        cat_size = 0
+        total_size = values.size(0)
+        for i in length_per_key:
+            cat_size += i
+            torch._check(cat_size <= total_size)
+        torch._check(cat_size == total_size)
+    values_list = torch.split(values, length_per_key)
+    if variable_stride_per_key:
+        split_lengths = torch.split(lengths, stride_per_key)
+        split_offsets = [
+            torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+            for lengths in split_lengths
+        ]
+    elif pt2_guard_size_oblivious(lengths.numel() > 0):
+        strided_lengths = lengths.view(len(keys), stride)
+        split_lengths = torch.unbind(
+            strided_lengths,
+            dim=0,
+        )
+        split_offsets = torch.unbind(
+            _batched_lengths_to_offsets(strided_lengths),
+            dim=0,
+        )
+    else:
+        split_lengths = torch.unbind(lengths, dim=0)
+        split_offsets = torch.unbind(lengths, dim=0)
+
+    if weights is not None:
+        weights_list = torch.split(weights, length_per_key)
+        for idx, key in enumerate(keys):
+            length = split_lengths[idx]
+            offset = split_offsets[idx]
+            _jt_dict[key] = JaggedTensor(
+                lengths=length,
+                offsets=offset,
+                values=values_list[idx],
+                weights=weights_list[idx],
             )
-            split_offsets = torch.unbind(
-                (
-                    _batched_lengths_to_offsets(lengths.view(-1, stride))
-                    if pt2_guard_size_oblivious(lengths.numel() != 0)
-                    else lengths
-                ),
-                dim=0,
+    else:
+        for idx, key in enumerate(keys):
+            length = split_lengths[idx]
+            offset = split_offsets[idx]
+            _jt_dict[key] = JaggedTensor(
+                lengths=length,
+                offsets=offset,
+                values=values_list[idx],
             )
-
-        if weights is not None:
-            weights_list = torch.split(weights, length_per_key)
-            for idx, key in enumerate(keys):
-                length = split_lengths[idx]
-                offset = split_offsets[idx]
-                _jt_dict[key] = JaggedTensor(
-                    lengths=length,
-                    offsets=offset,
-                    values=values_list[idx],
-                    weights=weights_list[idx],
-                )
-        else:
-            for idx, key in enumerate(keys):
-                length = split_lengths[idx]
-                offset = split_offsets[idx]
-                _jt_dict[key] = JaggedTensor(
-                    lengths=length,
-                    offsets=offset,
-                    values=values_list[idx],
-                )
-        jt_dict = _jt_dict
-    return jt_dict
+    return _jt_dict
 
 
 @torch.fx.wrap
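
The diff flattens the `if jt_dict is None:` wrapper into an early return and replaces the inline conditional expressions on the fixed-stride path with an explicit `elif`/`else` split, hoisting `lengths.view(len(keys), stride)` into `strided_lengths` so it is computed once. For intuition on what that fixed-stride branch computes, here is a minimal standalone sketch, not the torchrec implementation: `split_lengths_and_offsets` is a hypothetical helper, and it assumes `_batched_lengths_to_offsets` produces row-wise offsets via a leading-zero cumsum.

```python
import torch
from typing import Dict, List, Tuple


def split_lengths_and_offsets(
    lengths: torch.Tensor, keys: List[str], stride: int
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """Split a KJT-style flat lengths tensor into per-key (lengths, offsets)."""
    if lengths.numel() == 0:
        # Mirrors the diff's else branch: nothing to reshape, so hand the
        # empty tensor back for both lengths and offsets.
        return {key: (lengths, lengths) for key in keys}
    # One row of per-batch lengths for each key, as in
    # lengths.view(len(keys), stride) in the diff.
    strided = lengths.view(len(keys), stride)
    # Leading-zero cumsum per row yields offsets of length stride + 1;
    # this is the assumed behavior of _batched_lengths_to_offsets.
    zeros = torch.zeros(len(keys), 1, dtype=strided.dtype)
    offsets = torch.cat([zeros, strided.cumsum(dim=1)], dim=1)
    return {key: (strided[i], offsets[i]) for i, key in enumerate(keys)}


# Example: two keys, batch size (stride) 3, six flat (key, batch) lengths.
lengths = torch.tensor([2, 0, 1, 1, 1, 2])
out = split_lengths_and_offsets(lengths, ["f1", "f2"], stride=3)
print(out["f1"])  # (tensor([2, 0, 1]), tensor([0, 2, 2, 3]))
```

Under the same assumptions, the `torch._check` calls in the dynamo-compile branch serve only to assert to the compiler that `sum(length_per_key) == values.size(0)`, so the subsequent `torch.split(values, length_per_key)` is known to consume the values tensor exactly.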