From 509bee15efcabfd630925ad97f4c4386aa76e32c Mon Sep 17 00:00:00 2001
From: Bert Maher
Date: Tue, 9 Apr 2024 06:42:35 -0700
Subject: [PATCH] Benchmark launch latency

Summary: Triton kernel launch latency has been identified as a problem
recently; let's track it and drive it down

Reviewed By: sijiac

Differential Revision: D55904363

fbshipit-source-id: 3338a274282b2ee699822433c03785434fff2098
---
 .../operators/launch_latency/__init__.py | 67 +++++++++++++++++++
 torchbenchmark/util/triton_op.py         | 40 +++++++++++
 2 files changed, 107 insertions(+)
 create mode 100644 torchbenchmark/operators/launch_latency/__init__.py

diff --git a/torchbenchmark/operators/launch_latency/__init__.py b/torchbenchmark/operators/launch_latency/__init__.py
new file mode 100644
index 0000000000..76f140efa6
--- /dev/null
+++ b/torchbenchmark/operators/launch_latency/__init__.py
@@ -0,0 +1,67 @@
+import torch
+import triton
+import triton.language as tl
+
+from torchbenchmark.util.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    register_benchmark,
+    register_metric,
+)
+
+
+@triton.jit
+def nop_kernel():
+    pass
+
+
+@triton.jit
+def nop_with_args_kernel(
+    t1,
+    t2,
+    t3,
+    t4,
+    t5,
+    i1,
+    i2,
+    i3,
+    i4,
+    i5,
+    i6,
+    i7,
+    i8,
+    i9,
+    c1: tl.constexpr,
+    c2: tl.constexpr,
+    c3: tl.constexpr,
+    c4: tl.constexpr,
+    c5: tl.constexpr,
+):
+    pass
+
+
+class Operator(BenchmarkOperator):
+    DEFAULT_METRICS = ["walltime"]
+
+    def get_input_iter(self):
+        yield tuple()
+        targs = [torch.zeros(1, device="cuda") for _ in range(5)]
+        iargs = [1 for _ in range(9)]
+        cargs = [32 for _ in range(5)]
+        yield tuple([*targs, *iargs, *cargs])
+
+    def get_x_val(self, example_inputs) -> float:
+        return len(example_inputs)
+
+    @register_benchmark()
+    def nop_triton_kernel(self, *args):
+        if len(args) == 0:
+            return lambda: nop_kernel[1,]()
+        return lambda: nop_with_args_kernel[1,](*args)
+
+    @register_benchmark(baseline=True)
+    def nop_python_function(self, *args):
+        def nop():
+            pass
+
+        return nop
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index f73c7188ae..4396445e8e 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -3,6 +3,7 @@
 from enum import Enum
 import argparse
 import random
+import time
 import triton
 import torch
 import gc
@@ -36,6 +37,35 @@ class Mode(Enum):
     FWD_BWD = 3
     FWD_NO_GRAD = 4
 
+def do_bench_walltime(fn, warmup=25, rep=100):
+    fn()
+    torch.cuda.synchronize()
+
+    start_time = time.perf_counter()
+    for _ in range(5):
+        fn()
+    torch.cuda.synchronize()
+    end_time = time.perf_counter()
+    estimate_ms = (end_time - start_time) * 1e3 / 5
+
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    # Benchmark
+    start_time = time.perf_counter()
+    for _ in range(n_repeat):
+        fn()
+    torch.cuda.synchronize()
+    end_time = time.perf_counter()
+    wall_time_ms = (end_time - start_time) * 1e3 / n_repeat
+    return wall_time_ms
+
 @dataclass
 class BenchmarkOperatorMetrics:
     # latency in ms
@@ -46,6 +76,8 @@ class BenchmarkOperatorMetrics:
     speedup: Optional[float]
     # accuracy over baseline
     accuracy: Optional[bool]
+    # wall time
+    walltime: Optional[float]
     # error message
     error_msg: Optional[str]
     # extra metrics
@@ -404,6 +436,7 @@ def _do_bench(self,
         tflops = []
         speedup = None
         accuracy = None
+        walltime = None
         error_msg = None
         try:
             fn = self._get_bm_func(fn_name)
@@ -417,6 +450,12 @@ def _do_bench(self,
                 quantiles=quantiles,
                 grad_to_none=self.get_grad_to_none(self.example_inputs),
             )
+            if "walltime" in self.required_metrics:
+                walltime = do_bench_walltime(
+                    fn,
+                    warmup=warmup,
+                    rep=rep,
+                )
             if "speedup" in self.required_metrics:
                 speedup = numpy.median(self.baseline_metrics.latency) / numpy.median(latency) \
                     if self.baseline_metrics and self.baseline_metrics.latency else None
@@ -429,6 +468,7 @@ def _do_bench(self,
             tflops=None,
             speedup=speedup,
             accuracy=accuracy,
+            walltime=walltime,
             error_msg=error_msg,
             extra_metrics={},
        )
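
Usage sketch (illustrative, not part of the commit): a minimal way to exercise
the new helper directly, assuming this patch is applied and a CUDA GPU with a
working Triton install is available. The kernel and launch syntax mirror
nop_triton_kernel above; the import of do_bench_walltime refers to the
function this patch adds to triton_op.py, and the final print is our own.

    import torch
    import triton
    import triton.language as tl

    # do_bench_walltime is the wall-clock benchmarking helper added by this patch
    from torchbenchmark.util.triton_op import do_bench_walltime

    @triton.jit
    def nop_kernel():
        pass  # empty body: any measured time is pure launch overhead

    # Wall time per launch of an empty kernel approximates Triton's
    # kernel-launch latency (Python dispatch + driver submission).
    ms = do_bench_walltime(lambda: nop_kernel[1,](), warmup=25, rep=100)
    print(f"launch latency: {ms * 1e3:.3f} us")

Because the kernel does no work, CUDA-event timing would mostly miss the cost
of interest; measuring host-side wall time around synchronized launches is
what makes the walltime metric suitable for tracking launch overhead.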