diff --git a/torchbenchmark/operators/sum/__init__.py b/torchbenchmark/operators/sum/__init__.py
new file mode 100644
index 0000000000..a77a295cc4
--- /dev/null
+++ b/torchbenchmark/operators/sum/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/sum/kernels.py b/torchbenchmark/operators/sum/kernels.py
new file mode 100644
index 0000000000..7b5a945e88
--- /dev/null
+++ b/torchbenchmark/operators/sum/kernels.py
@@ -0,0 +1,31 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def triton_sum_kernel_scalar(
+    input_ptr,
+    output_ptr,
+    M,  # number of elements
+    BLOCK_SIZE_M: tl.constexpr,  # number of elements per block
+):
+    pid = tl.program_id(axis=0)  # i-th block of input
+
+    block_start = pid * BLOCK_SIZE_M
+    # offsets have shape equal to input shape
+    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)  # 1D vector of offsets spanning this program's block
+
+    # mask has shape equal to input shape
+    mask = offsets < M  # mask out offsets that are out of bounds for the input
+
+    # loaded values have shape equal to input shape
+    x = tl.load(input_ptr + offsets, mask=mask, other=0)  # load input; masked-out positions read 0 so they do not affect the sum
+
+    output = tl.sum(x)
+
+    # output_offsets have shape equal to output shape
+    output_offsets = tl.arange(0, 1)  # offsets for the scalar output pointer (output shape == (1,))
+
+    # stored values have shape equal to output shape
+    tl.store(output_ptr + output_offsets, output)  # store this block's sum to the scalar output
diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py
new file mode 100644
index 0000000000..475e7b4b3a
--- /dev/null
+++ b/torchbenchmark/operators/sum/operator.py
@@ -0,0 +1,93 @@
+import argparse
+from typing import Callable, Generator, List, Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+from torchbenchmark.util.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    register_benchmark,
+    register_metric,
+)
+
+from .kernels import triton_sum_kernel_scalar
+
+
+class Operator(BenchmarkOperator):
+
+    DEFAULT_METRICS = ["latency", "accuracy"]
+
+    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
+        super().__init__(mode=mode, device=device, extra_args=extra_args)
+        self.sizes = range(1, 17)
+
+    @register_benchmark()
+    def triton_sum(self, x: torch.Tensor):
+        x_1d = x.view(-1)
+        M = x_1d.shape[0]
+        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
+        BLOCK_SIZE_M = triton.next_power_of_2(M)  # a single program must cover all M elements; with BLOCK_SIZE_M < M, multiple programs would race on the scalar output
+
+        def _inner():
+            output = torch.zeros(1, device=x.device, dtype=x.dtype)
+
+            triton_sum_kernel_scalar[grid](
+                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
+            )
+
+            return output
+
+        return _inner
+
+    @register_benchmark(baseline=True)
+    def torch_sum(self, x: torch.Tensor):
+        # compute inside the returned callable so the baseline is actually timed
+        return lambda: torch.sum(x)
+
+    def get_x_val(self, example_inputs):
+        return len(example_inputs[0])
+
+    def get_x_vals(self) -> List[int]:
+        x_vals = []
+
+        x_vals.extend([2**n for n in self.sizes])
+        x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])
+
+        return x_vals
+
+    def get_input_iter(self) -> Generator:
+        # reduce to a scalar value
+        for size in self.get_x_vals():  # 1D tensor
+            input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
+            yield (input_1d,)
+
+        for size in self.get_x_vals():  # 2D tensor
+            if size < pow(2, 8):  # ensure we don't exceed floating point limitations
+                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
+                yield (input_2d,)
+
+        for size in self.get_x_vals():  # 3D tensor
+            if size < pow(2, 4):  # ensure we don't exceed floating point limitations
+                input_3d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
+                yield (input_3d,)
+
+    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
+        output = fn()
+        baseline_output = baseline_fn()
+        return torch.allclose(output, baseline_output, atol=1e-4)
+
+    @register_metric(skip_baseline=True)
+    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
+        return [ex.dim() for ex in example_inputs]
+
+    @register_metric()
+    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
+        gbps = (
+            lambda ms: 3
+            * example_inputs[0].element_size()
+            * example_inputs[0].numel()
+            / ms
+            * 1e-6
+        )
+        return list(map(gbps, metrics.latency if metrics.latency else [0]))
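
Not part of the patch: for anyone who wants to exercise triton_sum_kernel_scalar outside the benchmark harness, the sketch below mirrors the launch setup in Operator.triton_sum and the tolerance used by _get_accuracy. It assumes a CUDA device and that this patch is checked out so torchbenchmark.operators.sum.kernels is importable; the size 4097 is an arbitrary non-power-of-2 chosen to exercise the masked tail.

# standalone smoke test (illustrative only, not part of the diff)
import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_scalar

x = torch.randn(4097, device="cuda")  # non-power-of-2 so the mask path is hit
x_1d = x.view(-1)
M = x_1d.numel()
BLOCK_SIZE_M = triton.next_power_of_2(M)  # one program covers every element
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)  # evaluates to (1,)

out = torch.zeros(1, device=x.device, dtype=x.dtype)
triton_sum_kernel_scalar[grid](x_1d, out, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M)

assert torch.allclose(out, torch.sum(x), atol=1e-4)  # same tolerance as _get_accuracy
print(out.item(), torch.sum(x).item())

Because BLOCK_SIZE_M is chosen as next_power_of_2(M), the grid always collapses to a single program, which is what makes the non-atomic tl.store to the scalar output safe.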