
Commit b082638

Reset init state of KTRegroupAsDict after sharding inference model (pytorch#2337)

Summary:
Pull Request resolved: pytorch#2337

KTRegroupAsDict is a module that caches computations that remain consistent across batches when regrouping KeyedTensors into pooled embeddings for Ads models.

However, a bug caused an NE (normalized entropy) regression in TorchRec inference use cases: the order of the KeyedTensors fed into the regroup module can change after sharding, invalidating the previously cached computations. D61615045 was a temporary fix; this diff ensures that the module is reset by the official TorchRec API for sharding inference models.
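
As a minimal illustration of the failure mode (a sketch, not code from this diff; the feature names, dims, and batch size are made up): the module caches its permutation plan on the first forward call, keyed to the layout of the incoming KeyedTensors, so a post-sharding reorder would regroup the wrong slices until the cache is cleared.

import torch
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

regroup = KTRegroupAsDict([["f1"], ["f2"]], ["unweighted", "weighted"])

kt = KeyedTensor(keys=["f1", "f2"], length_per_key=[4, 8], values=torch.randn(2, 12))
out = regroup([kt])  # first call caches permute indices for this key layout

# After sharding, the same features can arrive in a different order; the stale
# cache would then regroup the wrong slices. Clearing the cache (what shard.py
# now does for every CacheMixin submodule) forces re-initialization:
regroup.clear_cache()
kt2 = KeyedTensor(keys=["f2", "f1"], length_per_key=[8, 4], values=torch.randn(2, 12))
out = regroup([kt2])  # recomputes the cached plan for the new layout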

Reviewed By: dstaay-fb, ZhengkaiZ

Differential Revision: D61703144

fbshipit-source-id: 10a13089654e426cecdc4939101cef8668286449
PaulZhang12 authored and facebook-github-bot committed Aug 28, 2024
1 parent 6c8b397 commit b082638
Showing 6 changed files with 195 additions and 1 deletion.
5 changes: 5 additions & 0 deletions torchrec/distributed/shard.py
@@ -28,6 +28,7 @@
    ShardingPlan,
)
from torchrec.distributed.utils import init_parameters
from torchrec.types import CacheMixin


def _join_module_path(path: str, name: str) -> str:
@@ -279,4 +280,8 @@ def _replace(_model: nn.Module, path: str = "") -> None:
    init_parameters(module, device)
    module = module.to(device)

    for submod in module.modules():
        if isinstance(submod, CacheMixin):
            submod.clear_cache()

    return module
64 changes: 64 additions & 0 deletions torchrec/distributed/test_utils/test_model.py
@@ -36,6 +36,7 @@
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection
from torchrec.modules.feature_processor import PositionWeightedProcessor
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Pipelineable

@@ -575,6 +576,69 @@ def _concat(
    return torch.cat([dense] + sparse_embeddings, dim=1)


class TestOverArchRegroupModule(nn.Module):
    """
    Basic nn.Module for testing, using KTRegroupAsDict to regroup pooled embeddings.
    Args:
        tables,
        weighted_tables,
        embedding_names,
        device
    Call Args:
        dense: torch.Tensor,
        sparse: KeyedTensor,
    Returns:
        torch.Tensor
    Example::
        TestOverArchRegroupModule(tables, weighted_tables)
    """

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        embedding_names: Optional[List[str]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        if device is None:
            device = torch.device("cpu")
        self._embedding_names: List[str] = (
            embedding_names
            if embedding_names
            else [feature for table in tables for feature in table.feature_names]
        )
        self._weighted_features: List[str] = [
            feature for table in weighted_tables for feature in table.feature_names
        ]
        in_features = (
            8
            + sum([table.embedding_dim * len(table.feature_names) for table in tables])
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in weighted_tables
                ]
            )
        )
        self.dhn_arch: nn.Module = TestDHNArch(in_features, device)
        self.regroup_module = KTRegroupAsDict(
            [self._embedding_names, self._weighted_features],
            ["unweighted", "weighted"],
        )

    def forward(
        self,
        dense: torch.Tensor,
        sparse: KeyedTensor,
    ) -> torch.Tensor:
        pooled_emb = self.regroup_module([sparse])
        values = list(pooled_emb.values())
        return self.dhn_arch(_concat(dense, values))


class TestOverArch(nn.Module):
    """
    Basic nn.Module for testing
33 changes: 33 additions & 0 deletions torchrec/inference/modules.py
@@ -19,6 +19,9 @@
import torchrec as trec
import torchrec.distributed as trec_dist
import torchrec.quant as trec_quant
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch.fx.passes.split_utils import getattr_recursive
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import (
@@ -555,3 +558,33 @@ def shard_quant_model(
    )

    return model, model_plan


def get_table_to_weights_from_tbe(
    model: torch.nn.Module,
) -> Dict[str, List[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
    table_to_weight = {}

    for module in model.modules():
        if isinstance(module, IntNBitTableBatchedEmbeddingBagsCodegen):
            weights = module.split_embedding_weights()
            for i, spec in enumerate(module.embedding_specs):
                table_to_weight[spec[0]] = weights[i]

    return table_to_weight


def assign_weights_to_tbe(
    model: torch.nn.Module,
    table_to_weight: Dict[str, List[Tuple[torch.Tensor, Optional[torch.Tensor]]]],
) -> None:
    for module in model.modules():
        if isinstance(module, IntNBitTableBatchedEmbeddingBagsCodegen):
            q_weights = []
            for spec in module.embedding_specs:
                assert spec[0] in table_to_weight, f"{spec[0]} not in table_to_weight"
                q_weights.append(table_to_weight[spec[0]])

            module.assign_embedding_weights(q_weights)

    return
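
Together, these helpers let callers snapshot quantized TBE weights before sharding and restore them afterward, since sharding re-initializes the embedding weights. A minimal sketch of the flow (mirroring the test below; `quantized_model` is assumed to come from quantize_inference_model):

table_to_weight = get_table_to_weights_from_tbe(quantized_model)

sharded_model, _ = shard_quant_model(
    quantized_model, world_size=2, compute_device="cpu", sharding_device="cpu"
)

# Sharding zeroes the TBE weights, so restore the snapshot afterward:
assign_weights_to_tbe(quantized_model, table_to_weight)
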
79 changes: 79 additions & 0 deletions torchrec/inference/tests/test_inference.py
@@ -11,16 +11,53 @@
import unittest
from argparse import Namespace

import torch
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.distributed.global_settings import set_propogate_device
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
    TestOverArchRegroupModule,
    TestSparseNN,
)

from torchrec.inference.dlrm_predict import (
    create_training_batch,
    DLRMModelConfig,
    DLRMPredictFactory,
)
from torchrec.inference.modules import (
    assign_weights_to_tbe,
    get_table_to_weights_from_tbe,
    quantize_inference_model,
    shard_quant_model,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig


class InferenceTest(unittest.TestCase):
    def setUp(self) -> None:
        num_features = 4
        num_weighted_features = 2

        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=(i + 1) * 10,
                embedding_dim=(i + 1) * 4,
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(num_features)
        ]
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=(i + 1) * 10,
                embedding_dim=(i + 1) * 4,
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(num_weighted_features)
        ]

    def test_dlrm_inference_package(self) -> None:
        args = Namespace()
        args.batch_size = 10
@@ -55,3 +92,45 @@ def test_dlrm_inference_package(self) -> None:
        DLRMPredictFactory(model_config).create_predict_module(
            world_size=1, device="cpu"
        )

    def test_regroup_module_inference(self) -> None:
        set_propogate_device(True)
        model = TestSparseNN(
            tables=self.tables,
            weighted_tables=self.weighted_tables,
            num_float_features=10,
            dense_device=torch.device("cpu"),
            sparse_device=torch.device("cpu"),
            over_arch_clazz=TestOverArchRegroupModule,
        )

        model.eval()
        _, local_batch = ModelInput.generate(
            batch_size=16,
            world_size=1,
            num_float_features=10,
            tables=self.tables,
            weighted_tables=self.weighted_tables,
        )

        with torch.inference_mode():
            output = model(local_batch[0])

            # Quantize the model and collect quantized weights
            quantized_model = quantize_inference_model(model)
            quantized_output = quantized_model(local_batch[0])
            table_to_weight = get_table_to_weights_from_tbe(quantized_model)

            # Shard the model; all weights are re-initialized to 0, so they must be reassigned
            sharded_quant_model, _ = shard_quant_model(
                quantized_model,
                world_size=2,
                compute_device="cpu",
                sharding_device="cpu",
            )
            assign_weights_to_tbe(quantized_model, table_to_weight)

            sharded_quant_output = sharded_quant_model(local_batch[0])

            self.assertTrue(torch.allclose(output, quantized_output, atol=1e-4))
            self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-4))
6 changes: 5 additions & 1 deletion torchrec/modules/regroup.py
@@ -17,6 +17,7 @@
    _kt_regroup_arguments,
    KeyedTensor,
)
from torchrec.types import CacheMixin


@torch.fx.wrap
@@ -114,7 +115,7 @@ def forward(self, values: List[torch.Tensor]) -> List[torch.Tensor]:
        )


class KTRegroupAsDict(torch.nn.Module):
class KTRegroupAsDict(torch.nn.Module, CacheMixin):
    """
    KTRegroupAsDict is a nn.Module that mirrors the behavior of the static method KeyedTensor.regroup_as_dict()
@@ -208,3 +209,6 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
            keyed_tensors, self._idx_key_pairs, self._dim
        )
        return _build_dict(self._keys, permuted_values, self._splits, self._dim)

    def clear_cache(self) -> None:
        self._is_inited = False
9 changes: 9 additions & 0 deletions torchrec/types.py
@@ -14,6 +14,15 @@
from torch import nn


class CacheMixin:
    """
    A mixin to allow modules that cache computation to clear the cache.
    """

    @abstractmethod
    def clear_cache(self) -> None: ...


class CopyMixIn:
    @abstractmethod
    def copy(self, device: torch.device) -> nn.Module: ...
