Skip to content

Commit

Permalink
Test case for EBC key-order change
Browse files Browse the repository at this point in the history
Summary:
# context
* [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/)
* this test case mimics the EBC key-order change after sharding
 {F1864056306}

# details
* it's a very simple model: EBC ---> KTRegroupAsDict
* we generate two EBCs: ebc1 and ebc2, such that the table orders are different:
```
        ebc1 = EmbeddingBagCollection(
            tables=[tb1_config, tb2_config, tb3_config],
            is_weighted=False,
        )
        ebc2 = EmbeddingBagCollection(
            tables=[tb1_config, tb3_config, tb2_config],
            is_weighted=False,
        )
```
* we export the model with ebc1 and unflatten the model, and then swap with ebc2 (you can think this as the the sharding process resulting a shardedEBC), so that we can mimic the key-order change as shown in the above graph
* the test checks the final results after KTRegroupAsDict are consistent with the original eager model

Reviewed By: PaulZhang12

Differential Revision: D62604419
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Sep 12, 2024
1 parent 75982ff commit c6d3d9e
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torchrec.ir.serializer import JsonSerializer

from torchrec.ir.utils import (
# bypass_pytree_ebc_regroup,
decapsulate_ir_modules,
encapsulate_ir_modules,
mark_dynamic_kjt,
Expand Down Expand Up @@ -521,3 +522,69 @@ def forward(
self.assertTrue(deserialized_model.regroup._is_inited)
for key in eager_out.keys():
self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)

def test_key_order_with_ebc_and_regroup(self) -> None:
tb1_config = EmbeddingBagConfig(
name="t1",
embedding_dim=3,
num_embeddings=10,
feature_names=["f1"],
)
tb2_config = EmbeddingBagConfig(
name="t2",
embedding_dim=4,
num_embeddings=10,
feature_names=["f2"],
)
tb3_config = EmbeddingBagConfig(
name="t3",
embedding_dim=5,
num_embeddings=10,
feature_names=["f3"],
)
id_list_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3", "f4", "f5"],
values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
)
ebc1 = EmbeddingBagCollection(
tables=[tb1_config, tb2_config, tb3_config],
is_weighted=False,
)
ebc2 = EmbeddingBagCollection(
tables=[tb1_config, tb3_config, tb2_config],
is_weighted=False,
)
ebc2.load_state_dict(ebc1.state_dict())
regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])
class myModel(nn.Module):
def __init__(self, ebc, regroup):
super().__init__()
self.ebc = ebc
self.regroup = regroup

def forward(
self,
features: KeyedJaggedTensor,
) -> Dict[str, torch.Tensor]:
return self.regroup([self.ebc(features)])

model = myModel(ebc1, regroup)
eager_out = model(id_list_features)

model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
model,
(id_list_features,),
{},
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model.ebc = ebc2
# bypass_pytree_ebc_regroup(deserialized_model)
deserialized_out = deserialized_model(id_list_features)
for key in eager_out.keys():
torch.testing.assert_close(deserialized_out[key], eager_out[key])

0 comments on commit c6d3d9e

Please sign in to comment.