Add sum reduction operator to TritonBench (#2282)
Summary:

Add a Triton reduction kernel for the `sum` operator where `dim=None` to TritonBench, following the [TritonBench guide](https://fb.workplace.com/notes/953949486404240). This implementation reduces an input tensor of any shape to a single scalar value.
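For reference, a minimal sketch of the semantics being matched: with `dim=None`, `torch.sum` reduces every element of its input to a single scalar, whatever the input's rank:

```python
import torch

x = torch.randn(16, 16)   # any rank works: 1D, 2D, 3D, ...
scalar = torch.sum(x)     # dim=None reduces every element to a 0-dim scalar
assert scalar.shape == torch.Size([])
```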

To measure the accuracy of the Triton reduction kernel, add an accuracy metric to the `sum` operator in TritonBench, which checks the Triton implementation against the baseline PyTorch implementation, referencing [`torchbenchmark/operators/gemm/operator.py`](https://www.internalfb.com/code/fbsource/[767bb6faa353685b84f08a39f36fdcf6ca170c85]/fbcode/pytorch/benchmark/torchbenchmark/operators/gemm/operator.py?lines=236). The output buffer is reset on every run of the Triton kernel so that stale results cannot leak into the measured output.
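Bitwise equality is too strict here, since the Triton kernel may accumulate in a different order than PyTorch; the metric therefore compares with an absolute tolerance. A small self-contained illustration of why (the blockwise split below is illustrative, not the kernel's actual schedule):

```python
import torch

x = torch.randn(2**12)
a = torch.sum(x)                          # single full reduction
b = torch.sum(x.view(64, 64).sum(dim=1))  # blockwise reduction, different order
# float addition is not associative, so bitwise equality may fail,
# but the two results agree within a small absolute tolerance
assert torch.allclose(a, b, atol=1e-4)
```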

To measure the performance of the Triton reduction kernel against PyTorch, add a `gbps` (memory bandwidth) metric, referencing [`torchbenchmark/operators/vector_add/operator.py`](https://www.internalfb.com/code/fbsource/[858eda681c7618f9427ba55cef8d4aba712cb26e]/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/operator.py?lines=19).
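The conversion in the `gbps` metric follows from unit arithmetic: latency is reported in milliseconds, so bytes / ms * 1e-6 gives GB/s. A worked sketch (the helper below is illustrative, not part of the diff):

```python
# bandwidth in GB/s = bytes moved / seconds elapsed / 1e9
#                   = (bytes / ms) * 1e3 / 1e9
#                   = (bytes / ms) * 1e-6
def gbps(bytes_moved: int, ms: float) -> float:
    return bytes_moved / ms * 1e-6

# e.g. reading 2**20 float32 elements (4 MiB) in 0.05 ms ~ 83.9 GB/s
print(gbps(4 * 2**20, 0.05))
```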

Referenced the existing [vector_add](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/) and [grouped_gemm](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/grouped_gemm/) TritonBench operators as templates for this implementation.

See the [TritonBench Operator Coverage Tracker](https://docs.google.com/spreadsheets/d/1091POOPSPsUnlNVEKaz2X_DQXdIwFv-fGOH_g9by-Zo/edit#gid=0) for current operator coverage in TritonBench.

Reviewed By: xuzhao9, davidberard98

Differential Revision: D58048782
jananisriram authored and facebook-github-bot committed Jun 6, 2024
1 parent f7b4bcc commit 1f8df93
Showing 3 changed files with 125 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/sum/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
31 changes: 31 additions & 0 deletions torchbenchmark/operators/sum/kernels.py
@@ -0,0 +1,31 @@
import torch
import triton
import triton.language as tl


@triton.jit
def triton_sum_kernel_scalar(
    input_ptr,
    output_ptr,
    M,  # number of elements
    BLOCK_SIZE_M: tl.constexpr,  # number of elements per block
):
    pid = tl.program_id(axis=0)  # i-th block of input

    block_start = pid * BLOCK_SIZE_M
    # offsets have shape equal to input shape
    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)  # 1D vector of offsets spanning this program's block

    # mask has shape equal to input shape
    mask = offsets < M  # mask out offsets that are out of bounds for input

    # loaded pointers have shape equal to input shape
    x = tl.load(input_ptr + offsets, mask=mask, other=0)  # masked-out elements load as 0 so they do not affect the sum

    output = tl.sum(x)

    # output_offsets have shape equal to output shape
    output_offsets = tl.arange(0, 1)  # offsets for the scalar output pointer (output shape == (1,))

    # stored pointers have shape equal to output shape
    tl.store(output_ptr + output_offsets, output)  # store the reduced scalar
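For context, a minimal launch sketch for this kernel (mirroring the launch in `operator.py` below; assumes a CUDA device):

```python
import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_scalar

x = torch.randn(2**12, device="cuda")
x_1d = x.view(-1)
M = x_1d.shape[0]
output = torch.zeros(1, device=x.device, dtype=x.dtype)

# a single block covers the whole input, so only one program writes the output
BLOCK_SIZE_M = triton.next_power_of_2(M)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
triton_sum_kernel_scalar[grid](x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M)

torch.testing.assert_close(output[0], torch.sum(x_1d), atol=1e-4, rtol=1e-4)
```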
93 changes: 93 additions & 0 deletions torchbenchmark/operators/sum/operator.py
@@ -0,0 +1,93 @@
import argparse
from typing import Callable, Generator, List, Optional, Tuple

import torch
import triton
import triton.language as tl
from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .kernels import triton_sum_kernel_scalar


class Operator(BenchmarkOperator):

    DEFAULT_METRICS = ["latency", "accuracy"]

    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
        super().__init__(mode=mode, device=device, extra_args=extra_args)
        self.sizes = range(1, 17)

    @register_benchmark()
    def triton_sum(self, x: torch.Tensor):
        x_1d = x.view(-1)
        M = x_1d.shape[0]
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
        # cover the whole input with a single block; with multiple blocks,
        # every program would store to the same scalar output and race
        BLOCK_SIZE_M = triton.next_power_of_2(M)

        def _inner():
            # reset the output buffer on every run so stale results from a
            # previous launch cannot leak into the measurement
            output = torch.zeros(1, device=x.device, dtype=x.dtype)

            triton_sum_kernel_scalar[grid](
                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
            )

            return output

        return _inner
    @register_benchmark(baseline=True)
    def torch_sum(self, x: torch.Tensor):
        # return a callable that performs the sum, so the benchmark times
        # the reduction itself rather than returning a precomputed result
        return lambda: torch.sum(x)

    def get_x_val(self, example_inputs):
        return len(example_inputs[0])

    def get_x_vals(self) -> List[int]:
        x_vals = []

        x_vals.extend([2**n for n in self.sizes])
        x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])

        return x_vals

    def get_input_iter(self) -> Generator:
        # all inputs reduce to a scalar value
        for size in self.get_x_vals():  # 1D tensor
            input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
            yield (input_1d,)

        for size in self.get_x_vals():  # 2D tensor
            if size < pow(2, 8):  # cap the element count to limit floating-point accumulation error
                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
                yield (input_2d,)

        for size in self.get_x_vals():  # 3D tensor
            if size < pow(2, 4):  # cap the element count to limit floating-point accumulation error
                input_3d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
                yield (input_3d,)

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        output = fn()
        baseline_output = baseline_fn()
        return torch.allclose(output, baseline_output, atol=1e-4)

    @register_metric(skip_baseline=True)
    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
        return [ex.dim() for ex in example_inputs]

    @register_metric()
    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
        # a sum reads each input element once and writes a single scalar, so
        # traffic is ~numel * element_size bytes (unlike vector_add, which
        # moves 3x that); bytes / ms * 1e-6 converts to GB/s
        gbps = (
            lambda ms: example_inputs[0].element_size()
            * example_inputs[0].numel()
            / ms
            * 1e-6
        )
        return list(map(gbps, metrics.latency or []))
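With the operator registered, it can be exercised through the standard TritonBench entry point; assuming the usual `run_benchmark.py` driver in `pytorch/benchmark`, the invocation would look something like `python run_benchmark.py triton --op sum --metrics latency,accuracy,gbps` (exact flags may differ; see the TritonBench guide linked above).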
