diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py
index 6322a3959..2eaceb192 100644
--- a/torchrec/distributed/tests/test_pt2_multiprocess.py
+++ b/torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -9,13 +9,14 @@
 #!/usr/bin/env python3
 
+import timeit
 import unittest
 
 from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
+import click
 import fbgemm_gpu.sparse_ops  # noqa: F401, E402
-
 import torch
 import torchrec
 import torchrec.pt2.checks
@@ -504,23 +505,22 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
 def _test_compile_fake_pg_fn(
     rank: int,
     world_size: int,
+    num_features: int = 5,
+    batch_size: int = 10,
+    num_embeddings: int = 256,
 ) -> None:
     sharding_type = ShardingType.TABLE_WISE.value
-    input_type = _InputType.SINGLE_BATCH
     torch_compile_backend = "eager"
     config = _TestConfig()
-    num_embeddings = 256
     # emb_dim must be % 4 == 0 for fbgemm
     emb_dim = 12
-    batch_size = 10
-    num_features: int = 5
     num_float_features: int = 8
     num_weighted_features: int = 1
     device: torch.Device = torch.device("cuda")
 
     store = FakeStore()
-    dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+    dist.init_process_group(backend="fake", rank=rank, world_size=2, store=store)
     pg: ProcessGroup = dist.distributed_c10d._get_default_group()
 
     topology: Topology = Topology(world_size=world_size, compute_device="cuda")
@@ -601,29 +601,17 @@ def _dmp(m: torch.nn.Module) -> DistributedModelParallel:  # pyre-ignore
 
     ins = []
     for _ in range(1 + n_extra_numerics_checks):
-        if input_type == _InputType.VARIABLE_BATCH:
-            (
-                _,
-                local_model_inputs,
-            ) = ModelInput.generate_variable_batch_input(
-                average_batch_size=batch_size,
-                world_size=world_size,
-                num_float_features=num_float_features,
-                # pyre-ignore
-                tables=mi.tables,
-            )
-        else:
-            (
-                _,
-                local_model_inputs,
-            ) = ModelInput.generate(
-                batch_size=batch_size,
-                world_size=world_size,
-                num_float_features=num_float_features,
-                tables=mi.tables,
-                weighted_tables=mi.weighted_tables,
-                variable_batch_size=False,
-            )
+        (
+            _,
+            local_model_inputs,
+        ) = ModelInput.generate(
+            batch_size=batch_size,
+            world_size=world_size,
+            num_float_features=num_float_features,
+            tables=mi.tables,
+            weighted_tables=mi.weighted_tables,
+            variable_batch_size=False,
+        )
         ins.append(local_model_inputs)
 
     local_model_input = ins[0][rank].to(device)
@@ -655,9 +643,6 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
         eager_out = dmp(kjt_ft, ff)
         reduce_to_scalar_loss(eager_out).backward()
 
-        if torch_compile_backend is None:
-            return
-
         ##### COMPILE #####
         with unittest.mock.patch(
             "torch._dynamo.config.skip_torchrec",
@@ -773,3 +758,54 @@ def test_compile_multiprocess_fake_pg(
             rank=0,
             world_size=2,
         )
+
+
+@click.command()
+@click.option(
+    "--repeat",
+    type=int,
+    default=1,
+    help="repeat times",
+)
+@click.option(
+    "--rank",
+    type=int,
+    default=0,
+    help="rank in the test",
+)
+@click.option(
+    "--world-size",
+    type=int,
+    default=2,
+    help="world_size in the test",
+)
+@click.option(
+    "--num-features",
+    type=int,
+    default=5,
+    help="num_features in the test",
+)
+@click.option(
+    "--batch-size",
+    type=int,
+    default=10,
+    help="batch_size in the test",
+)
+def compile_benchmark(
+    rank: int, world_size: int, num_features: int, batch_size: int, repeat: int
+) -> None:
+    run: str = (
+        f"_test_compile_fake_pg_fn(rank={rank}, world_size={world_size}, "
+        f"num_features={num_features}, batch_size={batch_size})"
+    )
+    print("*" * 20 + " compile_benchmark started " + "*" * 20)
+    t = timeit.timeit(stmt=run, number=repeat, globals=globals())
+    print("*" * 20 + " compile_benchmark completed " + "*" * 20)
+    print(
+        f"rank: {rank}, world_size: {world_size}, "
+        f"num_features: {num_features}, batch_size: {batch_size}, time: {t:.2f}s"
+    )
+
+
+if __name__ == "__main__":
+    compile_benchmark()
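
Usage sketch for the compile_benchmark entry point added above: a minimal example, assuming the module path torchrec.distributed.tests.test_pt2_multiprocess matches the file location in this diff and that a CUDA device is available for the fake-process-group run.

    # Hypothetical invocation; flags and defaults come from the click options added above.
    python -m torchrec.distributed.tests.test_pt2_multiprocess --rank 0 --world-size 2 --num-features 5 --batch-size 10 --repeat 3

Each repeat re-executes _test_compile_fake_pg_fn through timeit, so the reported time covers the eager run plus the torch.compile pass for the given feature count and batch size.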