Add sum reduction operator to TritonBench #2282

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions torchbenchmark/operators/sum/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
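With this re-export, the operator resolves at the package level, which is presumably how TritonBench discovers it; a minimal check (module path taken from the diff, assuming the repo root is on sys.path):

from torchbenchmark.operators.sum import Operator  # resolves via the re-export above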
31 changes: 31 additions & 0 deletions torchbenchmark/operators/sum/kernels.py
@@ -0,0 +1,31 @@
import triton
import triton.language as tl


@triton.jit
def triton_sum_kernel_scalar(
    input_ptr,
    output_ptr,
    M,  # number of elements
    BLOCK_SIZE_M: tl.constexpr,  # number of elements per block
):
    pid = tl.program_id(axis=0)  # i-th block of input

    block_start = pid * BLOCK_SIZE_M
    # offsets have shape equal to input shape: a 1D vector spanning this program's block
    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)

    # mask out offsets that are out of bounds for the input
    mask = offsets < M

    # out-of-bounds lanes load 0 so they do not contribute to the sum
    x = tl.load(input_ptr + offsets, mask=mask, other=0)

    output = tl.sum(x)

    # offsets for the scalar output pointer (output shape == (1,))
    output_offsets = tl.arange(0, 1)

    # store this program's sum; correct only when the grid launches a single
    # block, since concurrent programs would race on the same scalar
    tl.store(output_ptr + output_offsets, output)
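A minimal sketch of driving this kernel outside the benchmark harness, assuming a CUDA device is available; the single-block grid mirrors the launch in operator.py below:

import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_scalar

x = torch.randn(1000, device="cuda")  # device assumed
M = x.numel()
# one block covering every element, matching the kernel's non-atomic scalar store
BLOCK_SIZE_M = triton.next_power_of_2(M)
output = torch.zeros(1, device="cuda", dtype=x.dtype)
grid = (triton.cdiv(M, BLOCK_SIZE_M),)  # evaluates to (1,) for this BLOCK_SIZE_M
triton_sum_kernel_scalar[grid](x, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M)
assert torch.allclose(output[0], x.sum(), atol=1e-4)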
93 changes: 93 additions & 0 deletions torchbenchmark/operators/sum/operator.py
@@ -0,0 +1,93 @@
from typing import Callable, Generator, List, Optional

import torch
import triton
from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)

from .kernels import triton_sum_kernel_scalar


class Operator(BenchmarkOperator):

    DEFAULT_METRICS = ["latency", "accuracy"]

    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
        super().__init__(mode=mode, device=device, extra_args=extra_args)
        self.sizes = range(1, 17)

    @register_benchmark()
    def triton_sum(self, x: torch.Tensor):
        x_1d = x.view(-1)
        M = x_1d.shape[0]
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
        # cover all M elements with a single block: the kernel stores its sum
        # non-atomically, so a grid with more than one program would race on
        # the scalar output
        BLOCK_SIZE_M = triton.next_power_of_2(M)

        def _inner():
            output = torch.zeros(1, device=x.device, dtype=x.dtype)

            triton_sum_kernel_scalar[grid](
                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
            )

            return output

        return _inner

    @register_benchmark(baseline=True)
    def torch_sum(self, x: torch.Tensor):
        # defer the reduction into the returned callable so the harness times
        # torch.sum itself rather than the return of a precomputed result
        return lambda: torch.sum(x)

    def get_x_val(self, example_inputs):
        return len(example_inputs[0])

    def get_x_vals(self) -> List[int]:
        x_vals = []

        x_vals.extend([2**n for n in self.sizes])  # powers of 2
        # non-powers of 2 (n^2 - 1) to exercise the kernel's out-of-bounds mask
        x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])

        return x_vals

    def get_input_iter(self) -> Generator:
        # reduce to a scalar value
        for size in self.get_x_vals():  # 1D tensor
            input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
            yield (input_1d,)

        for size in self.get_x_vals():  # 2D tensor
            if size < pow(2, 8):  # cap size**2 elements to stay within block-size and accumulation limits
                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
                yield (input_2d,)

        for size in self.get_x_vals():  # 3D tensor
            if size < pow(2, 4):  # cap size**3 elements to stay within block-size and accumulation limits
                input_3d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
                yield (input_3d,)

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        output = fn()
        baseline_output = baseline_fn()
        return torch.allclose(output, baseline_output, atol=1e-4)

    @register_metric(skip_baseline=True)
    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
        return [ex.dim() for ex in example_inputs]

    @register_metric()
    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
        # a sum reads every input element once and writes a single scalar, so
        # bytes moved ~= element_size * numel; the factor of 3 used by
        # elementwise benchmarks does not apply to a reduction
        gbps = (
            lambda ms: example_inputs[0].element_size()
            * example_inputs[0].numel()
            / ms
            * 1e-6
        )
        return list(map(gbps, metrics.latency or []))
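A hedged sketch of exercising the operator end to end without the TritonBench runner; the mode string "fwd" and the assumption that @register_benchmark leaves the methods directly callable are illustrative, not confirmed by this diff:

import torch
from torchbenchmark.operators.sum.operator import Operator

op = Operator(mode="fwd", device="cuda", extra_args=None)  # mode value assumed
for inputs in op.get_input_iter():
    triton_fn = op.triton_sum(*inputs)  # assumes the decorator returns the method unchanged
    torch_fn = op.torch_sum(*inputs)
    assert op._get_accuracy(triton_fn, torch_fn), f"mismatch for shape {inputs[0].shape}"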