diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 9125d1979..7dbe84921 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -2943,7 +2943,9 @@ def _kt_unflatten( def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]: - return _kt_flatten(kt)[0] + _keys, _length_per_key = spec.context + res = KeyedTensor.regroup([kt], [_keys]) + return [res[0]] # The assumption here in torch.exporting KeyedTensor is that _length_per_key is static