Add launch latency benchmarks for triton.CompiledKernel and inductor
Summary:
There are a number of more detailed views into launch latency that we can get, in addition to the path we get from `triton.JitFunction`:
- `triton.compiler.CompiledKernel`, which is the lowest-level interface used by
  triton
- Inductor's `CachingAutotuner.run`, which is the lowest-level launch interface
  used by inductor
- launching a mostly-nop inductor kernel (can't be truly nop because inductor
  won't generate a kernel with nothing in it)
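
As a point of reference (not part of this commit), here is a minimal sketch of how the high-level `triton.JITFunction` launch relates to the low-level `CompiledKernel.run` path measured here. The `nop` kernel is hypothetical, and the version check mirrors the benchmark code in the diff below:

    import triton

    @triton.jit
    def nop():  # hypothetical no-op kernel, mirroring the benchmark's nop_kernel
        pass

    # High-level path: the JITFunction launch does argument binding and cache
    # lookup before the launch; on recent Triton it returns the CompiledKernel.
    compiled = nop[1,]()

    # Low-level path: re-launch the compiled binary directly. run()'s signature
    # changed across Triton releases, hence the same feature check the benchmark uses.
    metadata = getattr(compiled, "packed_metadata", compiled.metadata)
    if hasattr(triton.compiler.CompiledKernel, "launch_metadata"):
        compiled.run(1, 1, 1, 0, compiled.function, metadata, None, None, None)
    else:
        compiled.run(1, 1, 1, 1, 1, 1, 1, 1, 0, 0, compiled.function, None, None, metadata)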

Reviewed By: xuzhao9, chenyang78

Differential Revision: D56073036

fbshipit-source-id: c72b80eb016a5c2ea27717664e8a1ff0f35c705a
bertmaher authored and facebook-github-bot committed Apr 14, 2024
1 parent 5c0d0ce commit 9ff1725
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions torchbenchmark/operators/launch_latency/__init__.py
@@ -1,6 +1,9 @@
import torch
import triton
import triton.language as tl
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor import triton_heuristics
from torch._inductor.codecache import AsyncCompile

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
@@ -40,6 +43,55 @@ def nop_with_args_kernel(
    pass


@torch.compile
def trivial_add_kernel(*args):
    # Smallest inductor-compiled kernel we can get: inductor won't generate a
    # kernel with nothing in it, so sum a scalar with the inputs instead.
    return sum([torch.tensor(1.0, device="cuda"), *args])


async_compile = AsyncCompile()

inductor_nop = async_compile.triton(
    "inductor_nop",
    """
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_heuristics
@triton_heuristics.pointwise(
    size_hints=[1],
    triton_meta={'signature': {0: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(), equal_to_1=())]},
)
@triton.jit
def inductor_nop(x):
    pass
""",
    device_str="cuda",
)


inductor_nop_args = async_compile.triton(
    "inductor_nop_args",
    """
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_heuristics
@triton_heuristics.pointwise(
    size_hints=[1],
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9: 'i32', 10: 'i32', 11: 'i32', 12: 'i32', 13: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=(5, 6, 7, 8, 9, 10, 11, 12, 13))]},
)
@triton.jit
def inductor_nop_args(t1, t2, t3, t4, t5, i1, i2, i3, i4, i5, i6, i7, i8, i9):
    pass
""",
    device_str="cuda",
)


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["walltime"]

@@ -59,6 +111,41 @@ def nop_triton_kernel(self, *args):
            return lambda: nop_kernel[1,]()
        return lambda: nop_with_args_kernel[1,](*args)

    @register_benchmark()
    def nop_triton_compiled_kernel_run(self, *args):
        if len(args) == 0:
            bin = nop_kernel[1,]()

        else:
            bin = nop_with_args_kernel[1,](*args)
            args = args[:-5]  # remove tl.constexpr args
        function = bin.function
        metadata = (
            bin.packed_metadata if hasattr(bin, "packed_metadata") else bin.metadata
        )
        if hasattr(triton.compiler.CompiledKernel, "launch_metadata"):
            # Newer Triton: run(gridX, gridY, gridZ, stream, function, metadata,
            # launch_metadata, launch_enter_hook, launch_exit_hook, *args)
            return lambda: bin.run(
                1, 1, 1, 0, function, metadata, None, None, None, *args
            )
        else:
            # Older Triton run() signature: grid, warp/cluster/shared-memory/stream
            # values passed positionally, with the kernel metadata near the end
            return lambda: bin.run(
                1, 1, 1, 1, 1, 1, 1, 1, 0, 0, function, None, None, metadata, *args
            )

    @register_benchmark()
    def nop_inductor_kernel_run(self, *args):
        # Launch the AsyncCompile-produced handles directly (inductor's
        # CachingAutotuner.run), passing the raw CUDA stream and grid callable.
        stream = get_raw_stream(0)
        grid = triton_heuristics.grid(1)

        if len(args) == 0:
            return lambda: inductor_nop.run(1, grid=grid, stream=stream)
        args = args[:-5]  # remove tl.constexpr args
        return lambda: inductor_nop_args.run(*args, grid=grid, stream=stream)

    @register_benchmark()
    def nop_inductor_kernel(self, *args):
        return lambda: trivial_add_kernel(*args)

    @register_benchmark(baseline=True)
    def nop_python_function(self, *args):
        def nop():
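
For completeness, these benchmarks would typically be driven through the repository's `triton` userbenchmark entry point (the exact CLI invocation is an assumption, not part of this diff), e.g.:

    python run_benchmark.py triton --op launch_latency

which exercises every function registered with `@register_benchmark()` in this operator, including the new `nop_triton_compiled_kernel_run`, `nop_inductor_kernel_run`, and `nop_inductor_kernel` entries.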
