Skip to content

Commit

Permalink
Add KT.regroup to kT pytree unflatten function to preserve the correc…
Browse files Browse the repository at this point in the history
…t key order

Summary:
# context
* this diff basically re-orders the keys in a KT after every flatten/unflatten step, based on the key order in the exported IR
* This diff is the short-term solution regarding [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/)
 {F1865696002}

Reviewed By: PaulZhang12

Differential Revision: D59238744
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Sep 13, 2024
1 parent b1c81a2 commit 659e6d4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
10 changes: 9 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2943,7 +2943,15 @@ def _kt_unflatten(


def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
    """Flatten a KeyedTensor according to a previously captured TreeSpec.

    Unlike a plain ``_kt_flatten``, this re-groups ``kt``'s values into the
    key order recorded in ``spec.context`` so that unflattening with the
    exported spec reconstructs the KT with the correct key order, even when
    the runtime KT's keys arrive in a different order than at export time.

    Args:
        kt: the KeyedTensor to flatten; its key order may differ from the spec's.
        spec: TreeSpec produced at torch.export time; ``spec.context`` holds
            ``(keys, length_per_key)`` captured from the original KT.

    Returns:
        A single-element list containing the values tensor permuted into the
        spec's key order.
    """
    _keys, _length_per_key = spec.context
    # please read https://fburl.com/workplace/8bei5iju for more context,
    # you can also consider use short_circuit_pytree_ebc_regroup with KTRegroupAsDict
    logger.warning(
        "KT's key order might change from spec from the torch.export, this could have perf impact. "
        f"{kt.keys()} vs {_keys}"
    )
    # regroup kt's embeddings into the spec's key order; permute_multi_embedding
    # returns one tensor per requested group — here a single group with all keys.
    res = permute_multi_embedding([kt], [_keys])
    return [res[0]]


# The assumption here in torch.exporting KeyedTensor is that _length_per_key is static
Expand Down
33 changes: 29 additions & 4 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch
import torch.utils._pytree as pytree
from torch.fx._pytree import tree_flatten_spec
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
Expand Down Expand Up @@ -2691,21 +2692,45 @@ def test_string_values(self) -> None:

def test_pytree(self) -> None:
    """Round-trip a KeyedTensor through pytree flatten/unflatten, then verify
    that flattening a KT with a *different* key order against the original
    out_spec (via tree_flatten_spec) still reconstructs equivalent data."""
    tensor_list = [
        torch.Tensor([[1.0, 1.0]]).T,
        torch.Tensor([[2.0, 2.0], [3.0, 3.0]]).T,
    ]
    keys = ["dense_0", "dense_1"]
    kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=1, key_dim=1)
    # generate the out_spec in the torch.export run
    flattened, out_spec = pytree.tree_flatten(kt)

    # first element of flattened list should be the kt._values
    self.assertTrue(torch.equal(flattened[0], kt.values()))
    # re-construct the unflattened kt from the flattened list plus the out_spec
    unflattened = pytree.tree_unflatten(flattened, out_spec)

    self.assertTrue(isinstance(unflattened, KeyedTensor))
    self.assertListEqual(unflattened.keys(), keys)
    self.assertListEqual(unflattened._length_per_key, kt._length_per_key)

    # for ir export, key order in KT could change
    tensor_list = [
        torch.Tensor([[2.0, 2.0], [3.0, 3.0]]).T,
        torch.Tensor([[1.0, 1.0]]).T,
    ]
    keys = ["dense_1", "dense_0"]
    kt2 = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=1, key_dim=1)

    # flatten the kt2 based on previously generated out_spec
    # this is to mimic the exported_program module run
    # the kt2 could have different key order but out_spec is the same
    flattened2 = tree_flatten_spec(kt2, out_spec)

    # re-construct the unflattened kt from the flattened list plus the out_spec
    # the rebuilt kt2 should contain the same effective data as kt (ignoring key order)
    unflattened2 = pytree.tree_unflatten(flattened2, out_spec)
    self.assertTrue(isinstance(unflattened2, KeyedTensor))
    self.assertSetEqual(set(unflattened.keys()), set(unflattened2.keys()))
    for key in kt.keys():
        torch.testing.assert_close(unflattened[key], unflattened2[key])
        torch.testing.assert_close(kt[key], unflattened2[key])


class TestKeyedTensorRegroupOp(unittest.TestCase):
@repeat_test(device_str=["cpu", "meta", "cuda"])
Expand Down

0 comments on commit 659e6d4

Please sign in to comment.