refactor _maybe_compute_kjt_to_jt_dict (pytorch#2326)
Summary:
Pull Request resolved: pytorch#2326

# context
* want to resolve a graph break: [failures_and_restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpKJM3FI/failures_and_restarts.html), P1537573230 (a short sketch of the `torch._check` hints the error suggests follows the log below)
```
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time.  You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.  Could not guard on data-dependent expression Eq(((2*u48)//(u48 + u49)), 0) (unhinted: Eq(((2*u48)//(u48 + u49)), 0)).  (Size-like symbols: u49, u48)

Potential framework code culprit (scroll up for full backtrace):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind
    if guard_size_oblivious(t.shape[dim] == 0):

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u49,u48"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
...
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 2241, in to_dict
    _jt_dict = _maybe_compute_kjt_to_jt_dict(
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1226, in _maybe_compute_kjt_to_jt_dict
    split_lengths = torch.unbind(
```
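As a minimal, hedged sketch (not from this PR; names and values are illustrative, and exact behavior varies across torch versions), this is the hinting pattern the error message asks for: telling the compiler about an unbacked, data-dependent value via `torch._check` / `torch._check_is_size`:
```
import torch
import torch._dynamo

# needed so .item() is traced instead of causing a graph break
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def make_range(n_tensor: torch.Tensor) -> torch.Tensor:
    n = n_tensor.item()       # n is an unbacked SymInt at trace time
    torch._check_is_size(n)   # hint: n behaves like a size (n >= 0)
    torch._check(n > 0)       # hint: n is non-zero, so "== 0" style guards resolve
    return torch.arange(n)


print(make_range(torch.tensor(5)))  # tensor([0, 1, 2, 3, 4])
```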
* we added a [shape check](https://fburl.com/code/p02u4mck) before the `unbind` call:
```
        if pt2_guard_size_oblivious(lengths.numel() > 0):
            strided_lengths = lengths.view(-1, stride)
            if not torch.jit.is_scripting() and is_torchdynamo_compiling():
                torch._check(strided_lengths.shape[0] > 0)
                torch._check(strided_lengths.shape[1] > 0)
            split_lengths = torch.unbind(
                strided_lengths,
                dim=0,
            )
```
* however, the error is still there: the `guard_size_oblivious` branch inside the `unbind` decomposition still cannot be resolved
```
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind
    if guard_size_oblivious(t.shape[dim] == 0):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/symbolic_shapes.py", line 253, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/sym_node.py", line 503, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
```
* the `unbind` [implementation](https://fburl.com/code/20iue1ib) in `torch/_refs` that raises the guard (a small eager example follows the snippet):
```
@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dim = utils.canonicalize_dim(t.ndim, dim)
    torch._check_index(
        len(t.shape) > 0,
        lambda: "Dimension specified as 0 but tensor has no dimensions",
    )
    if guard_size_oblivious(t.shape[dim] == 0):  # <------- here
        return ()
    else:
        return tuple(
            torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
        )
```
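Eagerly, the two branches of this decomposition are easy to see; under `torch.compile` with an unbacked `shape[dim]`, neither branch can be chosen, which is the graph break above (small eager-only illustration, not from the PR):
```
import torch

# shape[dim] == 0 branch: unbind over an empty dimension returns ()
print(torch.unbind(torch.zeros(0, 3), dim=0))            # ()
# otherwise: one tensor per slice along dim
print(torch.unbind(torch.arange(6).view(2, 3), dim=0))   # (tensor([0, 1, 2]), tensor([3, 4, 5]))
```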
* with D61677207 there is [no graph break at _maybe_compute_kjt_to_jt_dict](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpNcI14t/failures_and_restarts.html); the key change is viewing `lengths` with the explicit first dimension `len(keys)` instead of `-1`, so `unbind` no longer guards on a data-dependent size
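A hedged, standalone sketch (illustrative names, not the torchrec code) of why the explicit first dimension helps: `len(keys)` is a plain Python int, so dim 0 of the view is static and `unbind` never sees a data-dependent expression:
```
import torch
from typing import List, Tuple


def split_lengths_per_key(
    lengths: torch.Tensor, keys: List[str], stride: int
) -> Tuple[torch.Tensor, ...]:
    # len(keys) is a plain int, so dim 0 of the view is not data-dependent
    strided = lengths.view(len(keys), stride)
    return torch.unbind(strided, dim=0)


print(split_lengths_per_key(torch.arange(6), ["f1", "f2"], 3))
# (tensor([0, 1, 2]), tensor([3, 4, 5]))
```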

Reviewed By: IvanKobzarev

Differential Revision: D55277785
TroyGarden authored and facebook-github-bot committed Aug 23, 2024
1 parent 6f0ea08 commit 1ce8f5a
torchrec/sparse/jagged_tensor.py: 55 additions & 53 deletions
```
@@ -20,6 +20,7 @@
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
 from torchrec.pt2.checks import (
     is_non_strict_exporting,
+    is_pt2_compiling,
     is_torchdynamo_compiling,
     pt2_check_size_nonzero,
     pt2_checks_all_is_size,
@@ -1201,62 +1202,63 @@ def _maybe_compute_kjt_to_jt_dict(
     if not length_per_key:
         return {}
 
-    if jt_dict is None:
-        _jt_dict: Dict[str, JaggedTensor] = {}
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            cat_size = 0
-            total_size = values.size(0)
-            for i in length_per_key:
-                cat_size += i
-                torch._check(cat_size <= total_size)
-            torch._check(cat_size == total_size)
-        values_list = torch.split(values, length_per_key)
-        if variable_stride_per_key:
-            split_lengths = torch.split(lengths, stride_per_key)
-            split_offsets = [
-                torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-                for lengths in split_lengths
-            ]
-        else:
-            split_lengths = torch.unbind(
-                (
-                    lengths.view(-1, stride)
-                    if pt2_guard_size_oblivious(lengths.numel() != 0)
-                    else lengths
-                ),
-                dim=0,
-            )
-            split_offsets = torch.unbind(
-                (
-                    _batched_lengths_to_offsets(lengths.view(-1, stride))
-                    if pt2_guard_size_oblivious(lengths.numel() != 0)
-                    else lengths
-                ),
-                dim=0,
-            )
-        if weights is not None:
-            weights_list = torch.split(weights, length_per_key)
-            for idx, key in enumerate(keys):
-                length = split_lengths[idx]
-                offset = split_offsets[idx]
-                _jt_dict[key] = JaggedTensor(
-                    lengths=length,
-                    offsets=offset,
-                    values=values_list[idx],
-                    weights=weights_list[idx],
-                )
-        else:
-            for idx, key in enumerate(keys):
-                length = split_lengths[idx]
-                offset = split_offsets[idx]
-                _jt_dict[key] = JaggedTensor(
-                    lengths=length,
-                    offsets=offset,
-                    values=values_list[idx],
-                )
-        jt_dict = _jt_dict
-    return jt_dict
+    if jt_dict is not None:
+        return jt_dict
+
+    _jt_dict: Dict[str, JaggedTensor] = {}
+    if not torch.jit.is_scripting() and is_pt2_compiling():
+        cat_size = 0
+        total_size = values.size(0)
+        for i in length_per_key:
+            cat_size += i
+            torch._check(cat_size <= total_size)
+        torch._check(cat_size == total_size)
+        torch._check_is_size(stride)
+    values_list = torch.split(values, length_per_key)
+    if variable_stride_per_key:
+        split_lengths = torch.split(lengths, stride_per_key)
+        split_offsets = [
+            torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+            for lengths in split_lengths
+        ]
+    elif pt2_guard_size_oblivious(lengths.numel() > 0):
+        strided_lengths = lengths.view(len(keys), stride)
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            torch._check(strided_lengths.size(0) > 0)
+            torch._check(strided_lengths.size(1) > 0)
+        split_lengths = torch.unbind(
+            strided_lengths,
+            dim=0,
+        )
+        split_offsets = torch.unbind(
+            _batched_lengths_to_offsets(strided_lengths),
+            dim=0,
+        )
+    else:
+        split_lengths = torch.unbind(lengths, dim=0)
+        split_offsets = torch.unbind(lengths, dim=0)
+
+    if weights is not None:
+        weights_list = torch.split(weights, length_per_key)
+        for idx, key in enumerate(keys):
+            length = split_lengths[idx]
+            offset = split_offsets[idx]
+            _jt_dict[key] = JaggedTensor(
+                lengths=length,
+                offsets=offset,
+                values=values_list[idx],
+                weights=weights_list[idx],
+            )
+    else:
+        for idx, key in enumerate(keys):
+            length = split_lengths[idx]
+            offset = split_offsets[idx]
+            _jt_dict[key] = JaggedTensor(
+                lengths=length,
+                offsets=offset,
+                values=values_list[idx],
+            )
+    return _jt_dict
 
 
 @torch.fx.wrap
```
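For reference, a hedged usage sketch of the call path this change keeps compile-friendly (assumes torchrec's `KeyedJaggedTensor` API as of this commit; keys and values are illustrative): `KeyedJaggedTensor.to_dict()` is what routes through `_maybe_compute_kjt_to_jt_dict`:
```
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# two keys, batch size (stride) 2; lengths is laid out key-major
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
    lengths=torch.tensor([2, 0, 1, 2]),
)
jt_dict = kjt.to_dict()            # Dict[str, JaggedTensor], one entry per key
print(jt_dict["f1"].to_dense())    # [tensor([1., 2.]), tensor([])]
print(jt_dict["f2"].to_dense())    # [tensor([3.]), tensor([4., 5.])]
```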
