diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index bc499f335..511ee00be 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -643,6 +643,127 @@ def trace_handler(prof) -> None: ) +def benchmark_func( + name: str, + bench_inputs: List[Dict[str, Any]], + prof_inputs: List[Dict[str, Any]], + world_size: int, + profile_dir: str, + num_benchmarks: int, + num_profiles: int, + # pyre-ignore[2] + func_to_benchmark: Any, + benchmark_func_kwargs: Optional[Dict[str, Any]], + rank: int, + device_type: str = "cuda", +) -> BenchmarkResult: + max_mem_allocated: List[int] = [] + if device_type == "cuda": + if rank == -1: + # Reset memory for measurement, no process per rank so do all + for di in range(world_size): + torch.cuda.reset_peak_memory_stats(di) + else: + torch.cuda.reset_peak_memory_stats(rank) + + start = [] + end = [] + if device_type == "cuda": + # Measure time taken for batches in bench_inputs + start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + + if benchmark_func_kwargs is None: + # Need this to unwrap + benchmark_func_kwargs = {} + + times = [] + if device_type == "cuda": + for i in range(num_benchmarks): + start[i].record() + func_to_benchmark(bench_inputs, **benchmark_func_kwargs) + end[i].record() + elif device_type == "cpu": + times = timeit.repeat( + lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs), + number=1, + repeat=num_benchmarks, + ) + + if device_type == "cuda": + if rank == -1: + for di in range(world_size): + torch.cuda.synchronize(di) + else: + torch.cuda.synchronize(rank) + + # TODO: First Benchmark Run for Eager Mode produces outlier + # Start counting after first as workaround for standard deviation + if device_type == "cuda": + elapsed_time = torch.tensor( + [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] + ) + else: + elapsed_time = torch.tensor(times) * 1e3 + + if device_type == "cuda": + if rank == -1: + # Add up all memory allocated in inference mode + for di in range(world_size): + b = torch.cuda.max_memory_allocated(di) + max_mem_allocated.append(b // 1024 // 1024) + else: + # Only add up memory allocated for current rank in training mode + b = torch.cuda.max_memory_allocated(rank) + max_mem_allocated.append(b // 1024 // 1024) + + if profile_dir != "": + # Only do profiling if output_dir is set + + # pyre-ignore[2] + def trace_handler(prof) -> None: + total_average = prof.profiler.total_average() + logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") + dir_path: str = profile_dir + if rank == 0: + trace_file: str = f"{dir_path}/trace-{name}.json" + else: + trace_file: str = f"{dir_path}/trace-{name}-{rank}.json" + return # only 1 rank should output in pg case, rank = 0 + logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") + prof.export_chrome_trace(trace_file) + + if device_type == "cuda": + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_flops=True, + with_modules=True, + on_trace_ready=trace_handler, + ) as p: + for i in range(num_profiles): + with record_function(f"## profile {i} ##"): + func_to_benchmark(prof_inputs, **benchmark_func_kwargs) + p.step() + + if rank == -1: + for di in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{di}")) + else: + torch.cuda.synchronize() + + return BenchmarkResult( + short_name=name, + elapsed_time=elapsed_time, + max_mem_allocated=max_mem_allocated, + rank=rank, + ) + + def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str: if sharding_type == ShardingType.TABLE_WISE: name = "tw-sharded" diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index f92cf82c9..35d7c774a 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -12,7 +12,7 @@ import copy import multiprocessing import os -from typing import Any, Callable, cast, Dict, List, Optional, Tuple +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union import click @@ -22,7 +22,7 @@ from torch import nn, optim from torch.optim import Optimizer from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.benchmark.benchmark_utils import benchmark +from torchrec.distributed.benchmark.benchmark_utils import benchmark_func from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.test_utils.multi_process import MultiProcessContext from torchrec.distributed.test_utils.test_model import ( @@ -46,6 +46,23 @@ from torchrec.test_utils import get_free_port +_pipeline_cls: Dict[str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]] = { + "base": TrainPipelineBase, + "sparse": TrainPipelineSparseDist, + "semi": TrainPipelineSemiSync, + "prefetch": PrefetchTrainPipelineSparseDist, +} + + +def _gen_pipelines( + pipelines: str, +) -> List[Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]]: + if pipelines == "all": + return list(_pipeline_cls.values()) + else: + return [_pipeline_cls[pipelines]] + + @click.command() @click.option( "--world_size", @@ -82,6 +99,30 @@ default=100, help="Pooling Factor.", ) +@click.option( + "--input_type", + type=str, + default="kjt", + help="Input type: kjt, td", +) +@click.option( + "--pipeline", + type=str, + default="all", + help="Pipeline to run: all, base, sparse, semi, prefetch", +) +@click.option( + "--multi_process", + type=bool, + default=True, + help="Run in multi process mode.", +) +@click.option( + "--profile", + type=str, + default="", + help="profile output directory", +) def main( world_size: int, n_features: int, @@ -89,6 +130,10 @@ def main( n_batches: int, batch_size: int, pooling_factor: int, + input_type: str, + pipeline: str, + multi_process: bool, + profile: str, ) -> None: """ Checks that pipelined training is equivalent to non-pipelined training. @@ -125,18 +170,34 @@ def main( batch_size=batch_size, world_size=world_size, pooling_factor=pooling_factor, + input_type=input_type, ) - _run_multi_process_test( - callable=runner, - tables=tables, - weighted_tables=weighted_tables, - sharding_type=ShardingType.TABLE_WISE.value, - kernel_type=EmbeddingComputeKernel.FUSED.value, - batches=batches, - fused_params={}, - world_size=world_size, - ) + if multi_process: + _run_multi_process_test( + callable=runner, + tables=tables, + weighted_tables=weighted_tables, + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + batches=batches, + fused_params={}, + world_size=world_size, + pipelines=pipeline, + profile=profile, + ) + else: + single_runner( + tables=tables, + weighted_tables=weighted_tables, + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + batches=batches, + fused_params={}, + world_size=1, + pipelines=pipeline, + profile=profile, + ) def _run_multi_process_test( @@ -179,6 +240,7 @@ def _generate_data( batch_size: int = 4096, world_size: int = 1, pooling_factor: int = 10, + input_type: str = "kjt", ) -> List[List[ModelInput]]: return [ ModelInput.generate( @@ -238,6 +300,8 @@ def runner( fused_params: Dict[str, Any], world_size: int, batches: List[List[ModelInput]], + pipelines: str, + profile: str, ) -> None: torch.autograd.set_detect_anomaly(True) @@ -269,12 +333,7 @@ def runner( }, ) bench_inputs = [batch[rank] for batch in batches] - for pipeline_clazz in [ - TrainPipelineBase, - TrainPipelineSparseDist, - TrainPipelineSemiSync, - PrefetchTrainPipelineSparseDist, - ]: + for pipeline_clazz in _gen_pipelines(pipelines=pipelines): if pipeline_clazz == TrainPipelineSemiSync: # pyre-ignore [28] pipeline = pipeline_clazz( @@ -292,8 +351,8 @@ def runner( pipeline.progress(iter(bench_inputs)) def _func_to_benchmark( - model: nn.Module, bench_inputs: List[ModelInput], + model: nn.Module, pipeline: TrainPipeline, ) -> None: dataloader = iter(bench_inputs) @@ -303,20 +362,17 @@ def _func_to_benchmark( except StopIteration: break - result = benchmark( + result = benchmark_func( name=pipeline_clazz.__name__, - model=sharded_model, + bench_inputs=bench_inputs, # pyre-ignore + prof_inputs=bench_inputs, # pyre-ignore num_benchmarks=5, - output_dir="", - warmup_inputs=[], - # pyre-ignore - bench_inputs=bench_inputs, - prof_inputs=[], + num_profiles=2, + profile_dir=profile, world_size=world_size, func_to_benchmark=_func_to_benchmark, - benchmark_func_kwargs={"pipeline": pipeline}, + benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline}, rank=rank, - enable_logging=False, ) if rank == 0: print( @@ -324,5 +380,79 @@ def _func_to_benchmark( ) +def single_runner( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + sharding_type: str, + kernel_type: str, + fused_params: Dict[str, Any], + world_size: int, + batches: List[List[ModelInput]], + pipelines: str, + profile: str, +) -> None: + device = torch.device("cuda") + torch.autograd.set_detect_anomaly(True) + model = TestSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=device, + sparse_device=device, + over_arch_clazz=TestOverArchLarge, + ).to(device) + + optimizer = optim.SGD( + [param for name, param in model.named_parameters() if "sparse" not in name], + lr=0.1, + ) + + bench_inputs = [batch[0] for batch in batches] + for pipeline_clazz in _gen_pipelines(pipelines=pipelines): + if pipeline_clazz == TrainPipelineSemiSync: + # pyre-ignore [28] + pipeline = pipeline_clazz( + model=model, + optimizer=optimizer, + device=device, + start_batch=0, + ) + else: + pipeline = pipeline_clazz( + model=model, + optimizer=optimizer, + device=device, + ) + pipeline.progress(iter(bench_inputs)) + + def _func_to_benchmark( + bench_inputs: List[ModelInput], + model: nn.Module, + pipeline: TrainPipeline, + ) -> None: + dataloader = iter(bench_inputs) + while True: + try: + pipeline.progress(dataloader) + except StopIteration: + break + + result = benchmark_func( + name=pipeline_clazz.__name__, + bench_inputs=bench_inputs, # pyre-ignore + prof_inputs=bench_inputs, # pyre-ignore + num_benchmarks=5, + num_profiles=2, + profile_dir=profile, + world_size=world_size, + func_to_benchmark=_func_to_benchmark, + benchmark_func_kwargs={"model": model, "pipeline": pipeline}, + rank=0, + ) + + print( + f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB" + ) + + if __name__ == "__main__": main()