From 1f8df93db7f953f0478354f0686ded43708b3ee6 Mon Sep 17 00:00:00 2001
From: Janani Sriram
Date: Thu, 6 Jun 2024 13:23:21 -0700
Subject: [PATCH] Add sum reduction operator to TritonBench (#2282)

Summary:
Add a Triton reduction kernel for the `sum` operator where `dim=None` to TritonBench, following the [TritonBench guide](https://fb.workplace.com/notes/953949486404240). This implementation reduces an input matrix of any shape to a single scalar value.

To measure the accuracy of the Triton reduction kernel, add an `accuracy` metric that checks the Triton implementation against the baseline PyTorch implementation, referencing [`torchbenchmark/operators/gemm/operator.py`](https://www.internalfb.com/code/fbsource/[767bb6faa353685b84f08a39f36fdcf6ca170c85]/fbcode/pytorch/benchmark/torchbenchmark/operators/gemm/operator.py?lines=236). The output buffer is reset on every run of the Triton kernel so that each run produces a fresh result.

To measure the performance of the Triton reduction kernel against PyTorch, add a `gbps` metric, referencing [`torchbenchmark/operators/vector_add/operator.py`](https://www.internalfb.com/code/fbsource/[858eda681c7618f9427ba55cef8d4aba712cb26e]/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/operator.py?lines=19).

The existing [vector_add](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/) and [grouped_gemm](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/grouped_gemm/) TritonBench operators served as templates for this implementation. See the [TritonBench Operator Coverage Tracker](https://docs.google.com/spreadsheets/d/1091POOPSPsUnlNVEKaz2X_DQXdIwFv-fGOH_g9by-Zo/edit#gid=0) for current operator coverage in TritonBench.

Reviewed By: xuzhao9, davidberard98

Differential Revision: D58048782
---
 torchbenchmark/operators/sum/__init__.py |  1 +
 torchbenchmark/operators/sum/kernels.py  | 31 ++++++++
 torchbenchmark/operators/sum/operator.py | 93 ++++++++++++++++++++++++
 3 files changed, 125 insertions(+)
 create mode 100644 torchbenchmark/operators/sum/__init__.py
 create mode 100644 torchbenchmark/operators/sum/kernels.py
 create mode 100644 torchbenchmark/operators/sum/operator.py

diff --git a/torchbenchmark/operators/sum/__init__.py b/torchbenchmark/operators/sum/__init__.py
new file mode 100644
index 0000000000..a77a295cc4
--- /dev/null
+++ b/torchbenchmark/operators/sum/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/sum/kernels.py b/torchbenchmark/operators/sum/kernels.py
new file mode 100644
index 0000000000..7b5a945e88
--- /dev/null
+++ b/torchbenchmark/operators/sum/kernels.py
@@ -0,0 +1,31 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def triton_sum_kernel_scalar(
+    input_ptr,
+    output_ptr,
+    M,  # number of elements
+    BLOCK_SIZE_M: tl.constexpr,  # number of elements per block
+):
+    pid = tl.program_id(axis=0)  # i-th block of the input
+
+    block_start = pid * BLOCK_SIZE_M
+    # offsets have shape equal to input shape
+    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)  # 1D vector of offsets spanning this program's block
+
+    # mask has shape equal to input shape
+    mask = offsets < M  # mask out offsets that are out of bounds for the input
+
+    # loaded pointers have shape equal to input shape
+    x = tl.load(input_ptr + offsets, mask=mask, other=0)  # load input; masked-out (out-of-bounds) elements are read as 0 and do not affect the sum
+
+    output = tl.sum(x)
+
+    # output_offsets have shape equal to output shape
+    output_offsets = tl.arange(0, 1)  # create offsets for the scalar output pointer (output shape == (1,))
+
+    # stored pointers have shape equal to output shape
+    tl.store(output_ptr + output_offsets, output)  # store the block's sum at the single output location
diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py
new file mode 100644
index 0000000000..475e7b4b3a
--- /dev/null
+++ b/torchbenchmark/operators/sum/operator.py
@@ -0,0 +1,93 @@
+import argparse
+from typing import Callable, Generator, List, Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+from torchbenchmark.util.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    register_benchmark,
+    register_metric,
+)
+
+from .kernels import triton_sum_kernel_scalar
+
+
+class Operator(BenchmarkOperator):
+
+    DEFAULT_METRICS = ["latency", "accuracy"]
+
+    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
+        super().__init__(mode=mode, device=device, extra_args=extra_args)
+        self.sizes = range(1, 17)
+
+    @register_benchmark()
+    def triton_sum(self, x: torch.Tensor):
+        x_1d = x.view(-1)
+        M = x_1d.shape[0]
+        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
+        BLOCK_SIZE_M = triton.next_power_of_2(M)  # one block covers all M elements; with BLOCK_SIZE_M < M, multiple programs would race on the non-atomic store to the scalar output
+
+        def _inner():
+            output = torch.zeros(1, device=x.device, dtype=x.dtype)
+
+            triton_sum_kernel_scalar[grid](
+                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
+            )
+
+            return output
+
+        return _inner
+
+    @register_benchmark(baseline=True)
+    def torch_sum(self, x: torch.Tensor):
+        # defer the computation into the returned callable so the baseline is timed per call
+        return lambda: torch.sum(x)
+
+    def get_x_val(self, example_inputs):
+        return len(example_inputs[0])
+
+    def get_x_vals(self) -> List[int]:
+        x_vals = []
+
+        x_vals.extend([2**n for n in self.sizes])
+        x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])
+
+        return x_vals
+
+    def get_input_iter(self) -> Generator:
+        # reduce to a scalar value
+        for size in self.get_x_vals():  # 1D tensor
+            input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
+            yield (input_1d,)
+
+        for size in self.get_x_vals():  # 2D tensor
+            if size < pow(2, 8):  # keep the element count small enough to stay within floating-point limitations
+                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
+                yield (input_2d,)
+
+        for size in self.get_x_vals():  # 3D tensor
+            if size < pow(2, 4):  # keep the element count small enough to stay within floating-point limitations
+                input_3d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
+                yield (input_3d,)
+
+    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
+        output = fn()
+        baseline_output = baseline_fn()
+        return torch.allclose(output, baseline_output, atol=1e-4)
+
+    @register_metric(skip_baseline=True)
+    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
+        return [ex.dim() for ex in example_inputs]
+
+    @register_metric()
+    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
+        gbps = (
+            lambda ms: 3
+            * example_inputs[0].element_size()
+            * example_inputs[0].numel()
+            / ms
+            * 1e-6
+        )
+        return list(map(gbps, metrics.latency if metrics.latency else [0]))
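
Not part of the patch: a minimal standalone sketch of how the new kernel can be launched and checked against `torch.sum`, mirroring the `triton_sum` benchmark and the `_get_accuracy` check in `operator.py` above. It assumes a CUDA device and that the files added by this patch are importable from `torchbenchmark.operators.sum`.

```python
import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_scalar

# 1D input, similar to the sizes produced by get_input_iter().
x = torch.randn(2**12, device="cuda")
x_1d = x.view(-1)
M = x_1d.shape[0]

# A single block covers every element, so only one program writes the scalar output.
BLOCK_SIZE_M = triton.next_power_of_2(M)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)

output = torch.zeros(1, device=x.device, dtype=x.dtype)  # fresh output buffer per run
triton_sum_kernel_scalar[grid](x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M)

# Same comparison the accuracy metric performs against the PyTorch baseline.
print(torch.allclose(output, torch.sum(x), atol=1e-4))
```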
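
Also not part of the patch: the `gbps` metric above reduces to the arithmetic below. The helper name `gbps_of` is illustrative only; the factor of 3 is carried over from the referenced `vector_add` operator.

```python
import torch


def gbps_of(example_input: torch.Tensor, ms: float) -> float:
    # GB/s = bytes / milliseconds * 1e-6, with bytes = 3 * element_size * numel
    return 3 * example_input.element_size() * example_input.numel() / ms * 1e-6


x = torch.randn(2**16)      # fp32: 4 bytes per element
print(gbps_of(x, ms=0.01))  # 3 * 4 * 65536 / 0.01 * 1e-6 ≈ 78.6 GB/s
```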