Skip to content

Commit

Permalink
add trace and single process runner for pipeline benchmark (pytorch#2347
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: pytorch#2347

# context
* please refer to this [plan doc](https://docs.google.com/document/d/1E45sbCPVA7JzG18BFS0tTOQHMLETGuZxkoupRwTwqkM/edit#heading=h.o7xaxy435ue4)
* add trace for the pipeline benchmark
* add single process runner for the pipeline benchmark
 {F1832319035}

Differential Revision: D61637749
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 28, 2024
1 parent b082638 commit 8fe8afc
Show file tree
Hide file tree
Showing 2 changed files with 279 additions and 28 deletions.
121 changes: 121 additions & 0 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,127 @@ def trace_handler(prof) -> None:
)


def benchmark_func(
name: str,
bench_inputs: List[Dict[str, Any]],
prof_inputs: List[Dict[str, Any]],
world_size: int,
profile_dir: str,
num_benchmarks: int,
num_profiles: int,
# pyre-ignore[2]
func_to_benchmark: Any,
benchmark_func_kwargs: Optional[Dict[str, Any]],
rank: int,
device_type: str = "cuda",
) -> BenchmarkResult:
max_mem_allocated: List[int] = []
if device_type == "cuda":
if rank == -1:
# Reset memory for measurement, no process per rank so do all
for di in range(world_size):
torch.cuda.reset_peak_memory_stats(di)
else:
torch.cuda.reset_peak_memory_stats(rank)

start = []
end = []
if device_type == "cuda":
# Measure time taken for batches in bench_inputs
start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]

if benchmark_func_kwargs is None:
# Need this to unwrap
benchmark_func_kwargs = {}

times = []
if device_type == "cuda":
for i in range(num_benchmarks):
start[i].record()
func_to_benchmark(bench_inputs, **benchmark_func_kwargs)
end[i].record()
elif device_type == "cpu":
times = timeit.repeat(
lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
number=1,
repeat=num_benchmarks,
)

if device_type == "cuda":
if rank == -1:
for di in range(world_size):
torch.cuda.synchronize(di)
else:
torch.cuda.synchronize(rank)

# TODO: First Benchmark Run for Eager Mode produces outlier
# Start counting after first as workaround for standard deviation
if device_type == "cuda":
elapsed_time = torch.tensor(
[si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
)
else:
elapsed_time = torch.tensor(times) * 1e3

if device_type == "cuda":
if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)

if profile_dir != "":
# Only do profiling if output_dir is set

# pyre-ignore[2]
def trace_handler(prof) -> None:
total_average = prof.profiler.total_average()
logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}")
dir_path: str = profile_dir
if rank == 0:
trace_file: str = f"{dir_path}/trace-{name}.json"
else:
trace_file: str = f"{dir_path}/trace-{name}-{rank}.json"
return # only 1 rank should output in pg case, rank = 0
logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
prof.export_chrome_trace(trace_file)

if device_type == "cuda":
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_flops=True,
with_modules=True,
on_trace_ready=trace_handler,
) as p:
for i in range(num_profiles):
with record_function(f"## profile {i} ##"):
func_to_benchmark(prof_inputs, **benchmark_func_kwargs)
p.step()

if rank == -1:
for di in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{di}"))
else:
torch.cuda.synchronize()

return BenchmarkResult(
short_name=name,
elapsed_time=elapsed_time,
max_mem_allocated=max_mem_allocated,
rank=rank,
)


def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str:
if sharding_type == ShardingType.TABLE_WISE:
name = "tw-sharded"
Expand Down
Loading

0 comments on commit 8fe8afc

Please sign in to comment.