Benchmark launch latency
Summary: Triton kernel launch latency has been identified as a problem recently; let's track it and drive it down.
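
For context, here is how one might run the new operator from the TorchBench CLI; the exact entry point and flags are an assumption based on the usual triton userbenchmark driver, not something shown in this diff:

    # Hypothetical invocation (entry point and flags assumed).
    python run_benchmark.py triton --op launch_latency --metrics walltime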

Reviewed By: sijiac

Differential Revision: D55904363

fbshipit-source-id: 3338a274282b2ee699822433c03785434fff2098
bertmaher authored and facebook-github-bot committed Apr 9, 2024
1 parent 161f2cc commit 509bee1
Showing 2 changed files with 107 additions and 0 deletions.
67 changes: 67 additions & 0 deletions torchbenchmark/operators/launch_latency/__init__.py
@@ -0,0 +1,67 @@
import torch
import triton
import triton.language as tl

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)


@triton.jit
def nop_kernel():
    pass


@triton.jit
def nop_with_args_kernel(
    t1,
    t2,
    t3,
    t4,
    t5,
    i1,
    i2,
    i3,
    i4,
    i5,
    i6,
    i7,
    i8,
    i9,
    c1: tl.constexpr,
    c2: tl.constexpr,
    c3: tl.constexpr,
    c4: tl.constexpr,
    c5: tl.constexpr,
):
    pass


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["walltime"]

    def get_input_iter(self):
        # Case 1: launch with no arguments at all.
        yield tuple()
        # Case 2: launch with 5 tensor args, 9 int args, and 5 constexpr args.
        targs = [torch.zeros(1, device="cuda") for _ in range(5)]
        iargs = [1 for _ in range(9)]
        cargs = [32 for _ in range(5)]
        yield tuple([*targs, *iargs, *cargs])

    def get_x_val(self, example_inputs) -> float:
        return len(example_inputs)

    @register_benchmark()
    def nop_triton_kernel(self, *args):
        if len(args) == 0:
            return lambda: nop_kernel[1,]()
        return lambda: nop_with_args_kernel[1,](*args)

    @register_benchmark(baseline=True)
    def nop_python_function(self, *args):
        def nop():
            pass

        return nop
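
To see what the operator measures in isolation, here is a minimal, self-contained sketch (an illustration written for this note, not part of the diff; it assumes a CUDA device and a working Triton install) that times the host-side cost of launching the no-op kernel:

import time

import torch
import triton


@triton.jit
def nop_kernel():
    pass


# One launch to trigger JIT compilation, then sync so timing excludes it.
nop_kernel[1,]()
torch.cuda.synchronize()

n = 10_000
start = time.perf_counter()
for _ in range(n):
    nop_kernel[1,]()
torch.cuda.synchronize()
end = time.perf_counter()
print(f"~{(end - start) / n * 1e6:.1f} us per launch")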
40 changes: 40 additions & 0 deletions torchbenchmark/util/triton_op.py
@@ -3,6 +3,7 @@
from enum import Enum
import argparse
import random
import time
import triton
import torch
import gc
@@ -36,6 +37,35 @@ class Mode(Enum):
    FWD_BWD = 3
    FWD_NO_GRAD = 4

def do_bench_walltime(fn, warmup=25, rep=100):
    fn()
    torch.cuda.synchronize()

    # Estimate the per-call runtime (in ms) from 5 timed iterations.
    start_time = time.perf_counter()
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    estimate_ms = (end_time - start_time) * 1e3 / 5

    # Derive iteration counts from the warmup/rep time budgets (in ms).
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()

    # Benchmark
    start_time = time.perf_counter()
    for _ in range(n_repeat):
        fn()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    wall_time_ms = (end_time - start_time) * 1e3 / n_repeat
    return wall_time_ms

@dataclass
class BenchmarkOperatorMetrics:
    # latency in ms
@@ -46,6 +76,8 @@ class BenchmarkOperatorMetrics:
    speedup: Optional[float]
    # accuracy over baseline
    accuracy: Optional[bool]
    # wall time in ms
    walltime: Optional[float]
    # error message
    error_msg: Optional[str]
    # extra metrics
@@ -404,6 +436,7 @@ def _do_bench(self,
        tflops = []
        speedup = None
        accuracy = None
        walltime = None
        error_msg = None
        try:
            fn = self._get_bm_func(fn_name)
@@ -417,6 +450,12 @@
                quantiles=quantiles,
                grad_to_none=self.get_grad_to_none(self.example_inputs),
            )
            if "walltime" in self.required_metrics:
                walltime = do_bench_walltime(
                    fn,
                    warmup=warmup,
                    rep=rep,
                )
            if "speedup" in self.required_metrics:
                speedup = numpy.median(self.baseline_metrics.latency) / numpy.median(latency) \
                    if self.baseline_metrics and self.baseline_metrics.latency else None
@@ -429,6 +468,7 @@
                tflops=None,
                speedup=speedup,
                accuracy=accuracy,
                walltime=walltime,
                error_msg=error_msg,
                extra_metrics={},
            )
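
Unlike triton.testing.do_bench, which times kernels with CUDA events and therefore reports device-side execution time, do_bench_walltime wraps the launch loop in time.perf_counter() around a final torch.cuda.synchronize(), so host-side launch overhead shows up in the result. A minimal usage sketch (the callable is a stand-in chosen for illustration):

import torch

from torchbenchmark.util.triton_op import do_bench_walltime

# Any callable works; a tiny GPU allocation stands in for a kernel launch here.
fn = lambda: torch.zeros(1, device="cuda")
print(f"{do_bench_walltime(fn, warmup=25, rep=100):.4f} ms per call")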
