tritonbench bf16xint16 matmul template (#2348)
Summary:
Pull Request resolved: #2348

Overall context: Before looking further into the bf16xint4 matmul, I'm planning to look into a bf16xint16 matmul first. The idea is that it will be the same as a bf16xbf16 matmul, except that the second operand needs to be cast from int16 to bf16 inside the Triton kernel before the dot product executes.

This PR is NOT fully functional yet; it's implemented this way to make review easier. Three kernels will be benchmarked here:
1. bf16xbf16 Triton kernel - selected as the "baseline" because, ideally, the bf16xint16 kernel should be as close as possible to this kernel.
2. bf16xint16 Triton kernel - NOT implemented yet; it will be implemented in a follow-up PR.
3. bf16x(convert(int16 -> bf16)) Triton kernel - i.e. convert int16 -> bf16, write the result to global memory, and then run the bf16xbf16 kernel.

Differential Revision: D59234085 (imported using ghimport)

Test Plan: Imported from OSS

Reviewed By: xuzhao9

Pulled By: davidberard98

fbshipit-source-id: 75a493dbd78ee1aa1f63926f6dd61a2e7388816c
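Since kernel.py itself is not part of the files shown in this diff, here is a minimal sketch of the idea described above, not the actual kernel from this PR: inside the matmul inner loop, the int16 weight tile is loaded and converted to bf16 right before the dot product. The kernel name, block sizes, and omission of boundary masks are all illustrative assumptions.

import triton
import triton.language as tl

@triton.jit
def bf16xint16_sketch_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Illustrative sketch only; boundary masks are omitted, so this assumes
    # M, N, K are divisible by the block sizes.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)            # bf16 activation tile
        b_i16 = tl.load(b_ptrs)        # int16 weight tile
        b = b_i16.to(tl.bfloat16)      # the extra cast vs. a plain bf16xbf16 kernel
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.bfloat16))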
1 parent: 8351474. Commit: ebd00aa. Showing 3 changed files with 655 additions and 0 deletions.
torchbenchmark/operators/bf16xint16_gemm/__init__.py (1 addition, 0 deletions)
@@ -0,0 +1 @@
from .bf16xint16_gemm import Operator
torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py (158 additions, 0 deletions)
@@ -0,0 +1,158 @@
""" | ||
Compute a bf16 (activation) x int16 (weight) gemm. | ||
A stepping stone to a fast int4_gemm (another TritonBench kernel) | ||
bf16xbf16 baseline implementation taken from the triton tutorial | ||
https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html | ||
and the bf16xint16 implementation is a modified version of the same | ||
tutorial kernel. | ||
The benchmarking file (i.e. this file) is mostly copied from the | ||
int4_gemm benchmarking file. | ||
""" | ||
import argparse
import os
import statistics

from typing import Any, List, Optional

import torch
import triton
import triton.language as tl

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)

from .kernel import (
    bf16xbf16_matmul,
    bf16xbf16_matmul_kernel,
    bf16xint16_matmul,
    bf16xint16_matmul_kernel,
)


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["tflops", "gbps", "latency"]

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args=tb_args, extra_args=extra_args)
        # `Group size` and `inner K tiles` are defaults from gpt-fast.
        self.group_size = 32
        self.inner_k_tiles = 8
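        # NOTE: neither attribute is referenced elsewhere in this file; they
        # appear to be carried over from the int4_gemm benchmark this file is
        # based on.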
    def get_input_iter(self):
        def args(B, Dout, Din):
            x = torch.randn(B, Din, device=self.device, dtype=torch.bfloat16)
            w = torch.randint(
                -(2**15),
                2**15 - 1,
                (Din, Dout),
                device=self.device,
                dtype=torch.int16,
            )
            return (x, w)

        # LLama-2 shapes w/ 8-way tensor parallelism.
        name_to_shapes_70b = {
            "attn.wqkv": (8192, 1280),
            "attn.w0": (1024, 8192),
            "ffn.w13": (8192, 7168),
            "ffn.w2": (3584, 8192),
        }
        for bsz in (1, 4, 16, 64, 256, 1024, 2**12, 2**14, 2**16):
            for name, (k, n) in name_to_shapes_70b.items():
                yield args(bsz, n, k)

    def get_x_val(self, example_inputs) -> float:
        x, w = example_inputs
        m, k = x.size()
        _, n = w.size()
        return (m, n, k)

    @register_benchmark(baseline=True)
    def bf16xbf16(self, x, w):
        x = x.reshape(-1, x.size(-1))
        w_bf16 = w.to(torch.bfloat16)
        return lambda: bf16xbf16_matmul(x, w_bf16)

    @register_benchmark()
    def bf16xint16(self, x, w):
        x = x.reshape(-1, x.size(-1))
        # TODO(davidberard98) fix this to pass in an int16
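        # Per the commit summary, the bf16xint16 kernel is not implemented
        # yet, so the weight is upcast on the host for now; the follow-up PR
        # is expected to pass w as int16 directly.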
        w = w.to(torch.bfloat16)
        return lambda: bf16xint16_matmul(x, w)

    @register_benchmark()
    def bf16xint16_casted(self, x, w):
        x = x.reshape(-1, x.size(-1))
        return lambda: bf16xbf16_matmul(x, w.to(torch.bfloat16))

    @register_metric()
    def best_config(self, fn, inputs, metrics):
        if "bf16xbf16" in str(fn):
            return str(bf16xbf16_matmul_kernel.best_config)
        if "bf16xint16" in str(fn) and "casted" not in str(fn):
            return str(bf16xint16_matmul_kernel.best_config)
        return ""

    @register_metric()
    def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
        def nbytes(t):
            return t.numel() * t.element_size()

        x, w = example_inputs
        c = fn()
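        # NOTE: the `nbytes(w) // 8` below appears to be carried over from the
        # int4_gemm benchmark's traffic accounting; for unpacked int16 weights,
        # nbytes(w) already reflects the bytes actually read.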
        gb = (sum(nbytes(t) for t in (x, c)) + nbytes(w) // 8) / 1e9
        return gb / metrics.latency * 1e3

    @register_metric()
    def tflops(
        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
    ) -> float:
        a, b = example_inputs
        m, k = a.size()
        _, n = b.size()
        flops = 2 * m * n * k
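        # metrics.latency is reported in ms, so the 1e3 factor converts
        # flops/ms to flops/s before scaling down to TFLOPS.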
        return flops / metrics.latency / 1e12 * 1e3

    def plot(self):
        @triton.testing.perf_report(
            triton.testing.Benchmark(
                x_names=[
                    "B",
                    "m",
                    "n",
                    "k",
                ],  # argument names to use as an x-axis for the plot
                x_vals=self.output.x_vals,  # different possible values for `x_name`
                line_arg="provider",  # argument name whose value corresponds to a different line in the plot
                line_vals=[
                    "torch",
                    "triton",
                ],  # possible values for `line_arg`
                line_names=[
                    "torch",
                    "triton",
                ],  # label name for the lines
                styles=[("blue", "-"), ("green", "-")],
                ylabel="tflops",  # label name for the y-axis
                plot_name="int4-gemm-performance",  # name for the plot. Used also as a file name for saving the plot.
                args={},  # values for function arguments not in `x_names` and `y_name`
            )
        )
        def _plot(B, m, n, k, provider):
            tflops = self.output.get_y_vals((B, m, n, k), provider, "tflops")
            return tflops

        save_path = "/tmp/bf16xint16_gemm"

        if not os.path.exists(save_path):
            os.mkdir(save_path)

        _plot.run(show_plots=True, print_data=True, save_path=save_path)
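For reference, TritonBench operators in this repository are typically invoked through the benchmark runner; the exact flags may vary by revision, so treat this invocation as an assumption:

python run_benchmark.py triton --op bf16xint16_gemm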