Skip to content

Commit

Permalink
Add coefficient of variation to the benchmark report.
Browse files Browse the repository at this point in the history
  • Loading branch information
chengjunlu committed Jul 8, 2024
1 parent e721be7 commit c0f624d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
20 changes: 15 additions & 5 deletions benchmarks/xetla_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def extract_kernels(funcs):
times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
# add coefficient of the variance.
std = torch.std(times)
mean = torch.mean(times)
cv = std / mean
ret.extend([mean.tolist(), cv.tolist()])
if len(ret) == 1:
ret = ret[0]
return ret
Expand Down Expand Up @@ -240,6 +245,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
y_min = [f'{x}-{label}-min' for x in bench.line_names]
y_max = [f'{x}-{label}-max' for x in bench.line_names]
y_vals += y_mean + y_min + y_max
y_vals += [f'{x}-CV' for x in bench.line_names]
x_names = list(bench.x_names)
df = pd.DataFrame(columns=x_names + y_vals)
for x in bench.x_vals:
Expand All @@ -252,11 +258,11 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
x_args = dict(zip(x_names, x))

row_vals = {}
for label in bench.ylabel:
for label in itertools.chain(bench.ylabel, ["CV"]):
row_vals[label] = ([], [], [])
for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
for id, label in enumerate(bench.ylabel):
for id, label in enumerate(itertools.chain(bench.ylabel, ["CV"])):
try:
y_mean, y_min, y_max = ret[id]
except TypeError:
Expand All @@ -266,9 +272,13 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
row_vals[label][2].append(y_max)
rows = []
for label in bench.ylabel:
rows += row_vals[label][0]
rows += row_vals[label][1]
rows += row_vals[label][2]
if len(row_vals[label][0]) > 0:
rows += row_vals[label][0]
if len(row_vals[label][1]) > 0:
rows += row_vals[label][1]
if len(row_vals[label][2]) > 0:
rows += row_vals[label][2]
rows += row_vals["CV"][0]
df.loc[len(df)] = list(x) + rows

if bench.plot_name:
Expand Down
18 changes: 10 additions & 8 deletions benchmarks/xetla_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,28 +130,30 @@ def benchmark(M, N, provider):
x = torch.randn(M, N, device='xpu', dtype=torch.bfloat16)
quantiles = [0.5, 0.0, 1.0]
if provider == 'torch-native':
ms, min_ms, max_ms = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10,
rep=10)
ms, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
warmup=10, rep=10)
if provider == 'triton':
triton_fn = lambda: softmax(x)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
ms, min_ms, max_ms = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
ms, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)

if provider == 'torch-jit':
ms, min_ms, max_ms = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10)
ms, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
rep=10)

if provider == 'xetla':
name = "softmax_shape_{}_{}".format(M, N)
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(x, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
ms, min_ms, max_ms = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
ms, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)

gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
tflops = lambda ms: 4 * x.nelement() * 1e-12 / (ms * 1e-3) # reduce-max, reduce-sum, elem-wise sub, elem-wise div
return (gbps(ms), gbps(max_ms), gbps(min_ms)), (tflops(ms), tflops(max_ms), tflops(min_ms))
gbps = lambda mean: 2 * x.nelement() * x.element_size() * 1e-9 / (mean * 1e-3)
tflops = lambda mean: 4 * x.nelement() * 1e-12 / (mean * 1e-3
) # reduce-max, reduce-sum, elem-wise sub, elem-wise div
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv


if __name__ == "__main__":
Expand Down

0 comments on commit c0f624d

Please sign in to comment.