From 1ce8f5a0923dab71ee53fc49aa7764710cc8e23f Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Thu, 22 Aug 2024 17:43:23 -0700 Subject: [PATCH] refactor _maybe_compute_kjt_to_jt_dict (#2326) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2326 # context * want to resolve graph break: [failures_and_restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpKJM3FI/failures_and_restarts.html), P1537573230 ``` Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs. Could not guard on data-dependent expression Eq(((2*u48)//(u48 + u49)), 0) (unhinted: Eq(((2*u48)//(u48 + u49)), 0)). (Size-like symbols: u49, u48) Potential framework code culprit (scroll up for full backtrace): File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind if guard_size_oblivious(t.shape[dim] == 0): For more information, run with TORCH_LOGS="dynamic" For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u49,u48" If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing User Stack (most recent call last): (snipped, see stack below for prefix) ... File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 2241, in to_dict _jt_dict = _maybe_compute_kjt_to_jt_dict( File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1226, in _maybe_compute_kjt_to_jt_dict split_lengths = torch.unbind( ``` * we added [shape check](https://fburl.com/code/p02u4mck): ``` if pt2_guard_size_oblivious(lengths.numel() > 0): strided_lengths = lengths.view(-1, stride) if not torch.jit.is_scripting() and is_torchdynamo_compiling(): torch._check(strided_lengths.shape[0] > 0) torch._check(strided_lengths.shape[1] > 0) split_lengths = torch.unbind( strided_lengths, dim=0, ) ``` * however the error is still there ``` File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind if guard_size_oblivious(t.shape[dim] == 0): File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/symbolic_shapes.py", line 253, in guard_size_oblivious return expr.node.guard_size_oblivious("", 0) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/sym_node.py", line 503, in guard_size_oblivious r = self.shape_env.evaluate_expr( ``` * [implementation](https://fburl.com/code/20iue1ib) ``` register_decomposition(aten.unbind) def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: from torch.fx.experimental.symbolic_shapes import guard_size_oblivious dim = utils.canonicalize_dim(t.ndim, dim) torch._check_index( len(t.shape) > 0, lambda: "Dimension specified as 0 but tensor has no dimensions", ) if guard_size_oblivious(t.shape[dim] == 0): # <------- here return () else: return tuple( torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) ) ``` * with D61677207 [no graph break at _maybe_compute_kjt_to_jt_dict](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpNcI14t/failures_and_restarts.html) Reviewed By: IvanKobzarev Differential Revision: D55277785 --- torchrec/sparse/jagged_tensor.py | 108 ++++++++++++++++--------------- 1 file changed, 55 insertions(+), 53 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 7e71acb19..7b2e6cbdb 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -20,6 +20,7 @@ from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node from torchrec.pt2.checks import ( is_non_strict_exporting, + is_pt2_compiling, is_torchdynamo_compiling, pt2_check_size_nonzero, pt2_checks_all_is_size, @@ -1201,62 +1202,63 @@ def _maybe_compute_kjt_to_jt_dict( if not length_per_key: return {} - if jt_dict is None: - _jt_dict: Dict[str, JaggedTensor] = {} + if jt_dict is not None: + return jt_dict + + _jt_dict: Dict[str, JaggedTensor] = {} + if not torch.jit.is_scripting() and is_pt2_compiling(): + cat_size = 0 + total_size = values.size(0) + for i in length_per_key: + cat_size += i + torch._check(cat_size <= total_size) + torch._check(cat_size == total_size) + torch._check_is_size(stride) + values_list = torch.split(values, length_per_key) + if variable_stride_per_key: + split_lengths = torch.split(lengths, stride_per_key) + split_offsets = [ + torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + for lengths in split_lengths + ] + elif pt2_guard_size_oblivious(lengths.numel() > 0): + strided_lengths = lengths.view(len(keys), stride) if not torch.jit.is_scripting() and is_torchdynamo_compiling(): - cat_size = 0 - total_size = values.size(0) - for i in length_per_key: - cat_size += i - torch._check(cat_size <= total_size) - torch._check(cat_size == total_size) - values_list = torch.split(values, length_per_key) - if variable_stride_per_key: - split_lengths = torch.split(lengths, stride_per_key) - split_offsets = [ - torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - for lengths in split_lengths - ] - else: - split_lengths = torch.unbind( - ( - lengths.view(-1, stride) - if pt2_guard_size_oblivious(lengths.numel() != 0) - else lengths - ), - dim=0, + torch._check(strided_lengths.size(0) > 0) + torch._check(strided_lengths.size(1) > 0) + split_lengths = torch.unbind( + strided_lengths, + dim=0, + ) + split_offsets = torch.unbind( + _batched_lengths_to_offsets(strided_lengths), + dim=0, + ) + else: + split_lengths = torch.unbind(lengths, dim=0) + split_offsets = torch.unbind(lengths, dim=0) + + if weights is not None: + weights_list = torch.split(weights, length_per_key) + for idx, key in enumerate(keys): + length = split_lengths[idx] + offset = split_offsets[idx] + _jt_dict[key] = JaggedTensor( + lengths=length, + offsets=offset, + values=values_list[idx], + weights=weights_list[idx], ) - split_offsets = torch.unbind( - ( - _batched_lengths_to_offsets(lengths.view(-1, stride)) - if pt2_guard_size_oblivious(lengths.numel() != 0) - else lengths - ), - dim=0, + else: + for idx, key in enumerate(keys): + length = split_lengths[idx] + offset = split_offsets[idx] + _jt_dict[key] = JaggedTensor( + lengths=length, + offsets=offset, + values=values_list[idx], ) - - if weights is not None: - weights_list = torch.split(weights, length_per_key) - for idx, key in enumerate(keys): - length = split_lengths[idx] - offset = split_offsets[idx] - _jt_dict[key] = JaggedTensor( - lengths=length, - offsets=offset, - values=values_list[idx], - weights=weights_list[idx], - ) - else: - for idx, key in enumerate(keys): - length = split_lengths[idx] - offset = split_offsets[idx] - _jt_dict[key] = JaggedTensor( - lengths=length, - offsets=offset, - values=values_list[idx], - ) - jt_dict = _jt_dict - return jt_dict + return _jt_dict @torch.fx.wrap