Update the Triton softmax micro-bench. (#1207)
1. Modify the softmax kernel for better performance in the N < 1024 cases (see the launch-shape sketch after this list).
2. Enable synchronized submitting (sync_submitting) by default in the benchmark.
3. Align the tile configuration of the XeTLA kernel with that of the Triton kernel.
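For reference, here is a minimal sketch of the launch-shape arithmetic that point 1 introduces in fused_softmax.py below. It assumes a device maximum work-group size of 1024 (the real code queries this value from the Triton driver at runtime); the point is that for N < 1024 each Triton program now covers several rows instead of one.

import triton

def launch_shape(n_rows, n_cols, max_work_group_size=1024):
    # One tile spans the whole (power-of-two padded) row.
    block_size_x = triton.next_power_of_2(n_cols)
    # Pack as many rows per program as the work-group allows, at least one.
    block_size_y = max(max_work_group_size // block_size_x, 1)
    # One persistent program per group of BLOCK_SIZE_Y rows.
    grid = (n_rows // block_size_y,)
    return block_size_x, block_size_y, grid

# A 4096 x 256 input packs four rows per program: launch_shape(4096, 256) == (256, 4, (1024,))
# A 4096 x 4096 input keeps one row per program:  launch_shape(4096, 4096) == (4096, 1, (4096,))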
chengjunlu authored May 30, 2024
1 parent 1142f12 commit ceadb1b
Showing 4 changed files with 108 additions and 39 deletions.
6 changes: 4 additions & 2 deletions benchmarks/xetla_benchmark/benchmark_testing.py
@@ -16,7 +16,7 @@ def synchronize():


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device='xpu'):
device='xpu', sync_submitting=True):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -78,6 +78,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
if sync_submitting:
synchronize()
# record time of `fn`
with record_function("__profile_kernel_of_func"):
fn()
@@ -288,7 +290,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b

if print_data:
print(bench.plot_name + ':')
print(df)
print(df.to_string())
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
index=False)
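The new sync_submitting flag in do_bench above amounts to draining the queue before each timed call, so the profiled region only ever has `fn` in flight. A condensed sketch of the timing loop with the flag enabled (the loop bound `n_repeat` and the surrounding structure are abridged from the upstream Triton do_bench helper and are not shown in this hunk):

# Abridged sketch of do_bench's timing loop with sync_submitting=True.
for _ in range(n_repeat):
    if grad_to_none is not None:
        for x in grad_to_none:
            x.grad = None          # avoid gradient accumulation across runs
    cache.zero_()                  # clear the L2 cache before each run
    if sync_submitting:
        synchronize()              # drain queued work so only `fn` is measured
    with record_function("__profile_kernel_of_func"):
        fn()                       # the kernel under test
synchronize()                      # wait for the last submission to finish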
74 changes: 48 additions & 26 deletions benchmarks/xetla_benchmark/fused_softmax.py
@@ -12,6 +12,7 @@

import triton
import triton.language as tl
from triton.runtime import driver

import xetla_benchmark
import xetla_benchmark.xetla_kernel as xetla_kernel
@@ -42,64 +43,83 @@ def naive_softmax(x):

@triton.autotune(
configs=[
triton.Config({}, num_warps=32),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=4),
triton.Config({"threads_per_warp": 32}, num_warps=32),
triton.Config({"threads_per_warp": 32}, num_warps=16),
triton.Config({"threads_per_warp": 32}, num_warps=8),
triton.Config({"threads_per_warp": 32}, num_warps=4),
triton.Config({"threads_per_warp": 16}, num_warps=64),
triton.Config({"threads_per_warp": 16}, num_warps=32),
triton.Config({"threads_per_warp": 16}, num_warps=16),
triton.Config({"threads_per_warp": 16}, num_warps=8),
triton.Config({"threads_per_warp": 16}, num_warps=4),
],
key=['n_cols', 'BLOCK_SIZE'],
key=['BLOCK_SIZE_X', 'BLOCK_SIZE_Y'],
)
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
BLOCK_SIZE_Y: tl.constexpr):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
row_idx = tl.program_id(0) * BLOCK_SIZE_Y
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
col_offsets = tl.arange(0, BLOCK_SIZE_X)
row_offsets = tl.arange(0, BLOCK_SIZE_Y)
offsets = col_offsets[None, :] + row_offsets[:, None] * input_row_stride
input_ptrs = row_start_ptr + offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
mask = col_offsets[None, :] < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
row_minus_max = row - tl.max(row, axis=1)[:, None]
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
denominator = tl.sum(numerator, axis=1)[:, None]
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
output_ptrs = output_row_start_ptr + offsets
tl.store(output_ptrs, softmax_output, mask=mask)


device = torch.xpu.current_device()
properties = driver.active.utils.get_device_properties(device)
MAX_WORK_GROUP_SIZE = properties["max_work_group_size"]


def softmax(x):
n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)

# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE_X = triton.next_power_of_2(n_cols)
BLOCK_SIZE_Y = MAX_WORK_GROUP_SIZE // BLOCK_SIZE_X
BLOCK_SIZE_Y = BLOCK_SIZE_Y if BLOCK_SIZE_Y > 0 else 1

# Allocate output
y = torch.empty_like(x)
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
# f the input matrix
softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE)
# Create a number of persistent programs.
softmax_kernel[(n_rows // BLOCK_SIZE_Y, )](y, x, x.stride(0), y.stride(0), n_cols, BLOCK_SIZE_X=BLOCK_SIZE_X,
BLOCK_SIZE_Y=BLOCK_SIZE_Y)
return y


@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[256, 1024, 2048, 4096], # different possible values for `x_name`
x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=[
'triton',
'torch-native',
'torch-jit',
# 'torch-native',
# 'torch-jit',
'xetla',
], # possible values for `line_arg``
line_names=[
"Triton",
"Torch (native)",
"Torch (jit)",
"Xetla",
# "Torch (native)",
# "Torch (jit)",
"XeTLA",
], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('green', '--'), ('black', ':')], # line styles
ylabel="GB/s", # label name for the y-axis
@@ -108,7 +128,7 @@ def softmax(x):
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device='xpu', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
quantiles = [0.5, 0.0, 1.0]
if provider == 'torch-native':
ms, min_ms, max_ms = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10,
rep=10)
Expand All @@ -120,16 +140,18 @@ def benchmark(M, N, provider):

if provider == 'torch-jit':
ms, min_ms, max_ms = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10)

if provider == 'xetla':
name = "softmax_shape_{}_{}".format(M, N)
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(x, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
ms, min_ms, max_ms = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)

gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
benchmark.run(show_plots=True, print_data=True)
benchmark.run(show_plots=False, print_data=True)
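The gbps lambda above counts one read plus one write of the whole tensor, and with quantiles now set to [0.5, 0.0, 1.0] the reported spread is the true min/max rather than the 20th/80th percentile (min_ms maps to the highest GB/s). A quick sanity check of the arithmetic with a made-up timing:

# M = 4096, N = 4096, bfloat16 (2 bytes per element); ms is hypothetical.
bytes_moved = 2 * 4096 * 4096 * 2            # read + write = 67,108,864 bytes
ms = 0.1                                      # illustrative measurement only
gbps = bytes_moved * 1e-9 / (ms * 1e-3)       # ~671 GB/s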
6 changes: 6 additions & 0 deletions benchmarks/xetla_kernel/python_main.cpp
@@ -57,6 +57,12 @@ PYBIND11_MODULE(xetla_kernel, m) {
"softmax forward (XeTLA)");
m.def("softmax_shape_4096_4096", &softmax<mat1_4096x4096_bf16_cfg0>,
"softmax forward (XeTLA)");
m.def("softmax_shape_4096_8192", &softmax<mat1_4096x8k_bf16_cfg0>,
"softmax forward (XeTLA)");
m.def("softmax_shape_4096_16384", &softmax<mat1_4096x16k_bf16_cfg0>,
"softmax forward (XeTLA)");
m.def("softmax_shape_4096_32768", &softmax<mat1_4096x32k_bf16_cfg0>,
"softmax forward (XeTLA)");
// bgemm: M=N=K [256, 512 ... 4096]
m.def("bgemm_shape_256_256_256", &bgemm<Test_256x256x256_row_row>,
"bgemm (XeTLA)");
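For context, the new bindings are looked up by name from the Python benchmark, following the pattern already used in fused_softmax.py above; a minimal usage sketch (shape and device values are illustrative only):

import torch
import xetla_benchmark.xetla_kernel as xetla_kernel

M, N = 4096, 8192
x = torch.randn(M, N, device='xpu', dtype=torch.bfloat16)
func = getattr(xetla_kernel, f"softmax_shape_{M}_{N}")  # e.g. softmax_shape_4096_8192
out = func(x, 0)  # same call signature as the lambda in fused_softmax.py above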
61 changes: 50 additions & 11 deletions benchmarks/xetla_kernel/softmax/softmax_config.hpp
@@ -36,9 +36,9 @@ class mat1_4096x1024_bf16_cfg0 {
static constexpr size_t mat_n = 1024;
static constexpr size_t mat_m = 4096;
static constexpr size_t wg_n = mat_n;
static constexpr size_t wg_m = 4; // 1 4 8 16
static constexpr size_t sg_n = mat_n;
static constexpr size_t sg_m = 1;
static constexpr size_t wg_m = 1; // 1 4 8 16
static constexpr size_t sg_n = wg_n / 8;
static constexpr size_t sg_m = wg_m;
using data_type_in = sycl::ext::oneapi::bfloat16;
using data_type_out = sycl::ext::oneapi::bfloat16;
using data_type_acc = float;
@@ -48,10 +48,10 @@ class mat1_4096x2048_bf16_cfg0 {
public:
static constexpr size_t mat_n = 2048;
static constexpr size_t mat_m = 4096;
static constexpr size_t wg_n = mat_n / 2;
static constexpr size_t wg_m = 4; // 1 4 8 16
static constexpr size_t sg_n = mat_n / 2;
static constexpr size_t sg_m = 1;
static constexpr size_t wg_n = mat_n;
static constexpr size_t wg_m = 1; // 1 4 8 16
static constexpr size_t sg_n = wg_n / 8;
static constexpr size_t sg_m = wg_m;
using data_type_in = sycl::ext::oneapi::bfloat16;
using data_type_out = sycl::ext::oneapi::bfloat16;
using data_type_acc = float;
@@ -61,10 +61,49 @@ class mat1_4096x4096_bf16_cfg0 {
public:
static constexpr size_t mat_n = 4096;
static constexpr size_t mat_m = 4096;
static constexpr size_t wg_n = mat_n / 2;
static constexpr size_t wg_m = 4; // 1 4 8 16
static constexpr size_t sg_n = mat_n / 2;
static constexpr size_t sg_m = 1;
static constexpr size_t wg_n = mat_n;
static constexpr size_t wg_m = 1; // 1 4 8 16
static constexpr size_t sg_n = wg_n / 32;
static constexpr size_t sg_m = wg_m;
using data_type_in = sycl::ext::oneapi::bfloat16;
using data_type_out = sycl::ext::oneapi::bfloat16;
using data_type_acc = float;
};

class mat1_4096x8k_bf16_cfg0 {
public:
static constexpr size_t mat_n = 4096 * 2;
static constexpr size_t mat_m = 4096;
static constexpr size_t wg_n = mat_n;
static constexpr size_t wg_m = 1; // 1 4 8 16
static constexpr size_t sg_n = wg_n / 32;
static constexpr size_t sg_m = wg_m;
using data_type_in = sycl::ext::oneapi::bfloat16;
using data_type_out = sycl::ext::oneapi::bfloat16;
using data_type_acc = float;
};

class mat1_4096x16k_bf16_cfg0 {
public:
static constexpr size_t mat_n = 4096 * 4;
static constexpr size_t mat_m = 4096;
static constexpr size_t wg_n = mat_n;
static constexpr size_t wg_m = 1; // 1 4 8 16
static constexpr size_t sg_n = wg_n / 32;
static constexpr size_t sg_m = wg_m;
using data_type_in = sycl::ext::oneapi::bfloat16;
using data_type_out = sycl::ext::oneapi::bfloat16;
using data_type_acc = float;
};

class mat1_4096x32k_bf16_cfg0 {
public:
static constexpr size_t mat_n = 4096 * 8;
static constexpr size_t mat_m = 4096;
static constexpr size_t wg_n = mat_n;
static constexpr size_t wg_m = 1; // 1 4 8 16
static constexpr size_t sg_n = wg_n / 32;
static constexpr size_t sg_m = wg_m;
using data_type_in = sycl::ext::oneapi::bfloat16;
using data_type_out = sycl::ext::oneapi::bfloat16;
using data_type_acc = float;
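The new and updated configs all follow one pattern, which is what point 3 of the commit message refers to: a work-group owns exactly one full row (wg_m = 1, wg_n = mat_n), mirroring the Triton kernel's row-per-program layout, and the row is split across sub-groups (wg_n / 8 for the 1024- and 2048-column shapes, wg_n / 32 from 4096 columns up). A quick numeric check for mat1_4096x4096_bf16_cfg0, using only the values shown above:

# Tile arithmetic for the 4096-column config (values copied from the class above).
mat_n, mat_m = 4096, 4096
wg_n, wg_m = mat_n, 1                  # one work-group per full row
sg_n, sg_m = wg_n // 32, wg_m          # row split across 32 sub-groups
assert (wg_n // sg_n, sg_n) == (32, 128)   # 32 sub-groups, 128 columns each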
