tritonbench bf16xint16 matmul template (#2348)
Summary:
Pull Request resolved: #2348

Overall context: Before looking further into the bf16xint4 matmul, I'm planning to look into a bf16xint16 matmul first. The idea is that this matmul is essentially the same as a bf16xbf16 matmul, except that the second operand needs to be cast from int16 to bf16 inside the triton kernel before the dot product executes.
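
To make the idea concrete, here is a minimal sketch (not the kernel added in this stack) of how the inner loop differs from the tutorial bf16xbf16 kernel: only the load of the B tile changes, gaining a `.to(tl.bfloat16)` before the `tl.dot`. Masking, autotuning, and grouped launch ordering are omitted, and all names, block sizes, and the divisibility assumption are illustrative.

```python
# Illustrative sketch only: a stripped-down Triton matmul with an in-kernel
# int16 -> bf16 cast of the weight tile. Assumes M, N, K are divisible by the
# block sizes (no masking) and omits autotuning.
import triton
import triton.language as tl


@triton.jit
def bf16xint16_matmul_kernel_sketch(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)    # bf16 tile of the activation
        b = tl.load(b_ptrs)    # int16 tile of the weight
        b = b.to(tl.bfloat16)  # the only extra step vs. the bf16xbf16 kernel
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.bfloat16))
```

If the cast stays fused in registers like this, the extra cost should be small relative to the `tl.dot`, which is why the bf16xbf16 kernel is the natural baseline.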

This PR is NOT fully functional yet; it's structured this way just to make review easier.

There are 3 kernels that will be benchmarked here (see the sketch after this list):
1. bf16xbf16 triton kernel - I've selected this kernel as the "baseline" because, ideally, we'd like the bf16xint16 kernel to get as close to it as possible.
2. bf16xint16 triton kernel - NOT implemented yet; it will be implemented in the follow-up PR.
3. bf16x(convert(int16 -> bf16)) triton kernel - i.e. convert int16 -> bf16, write the result to global memory, and then run the bf16xbf16 kernel on it.
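
At the call level, the three variants look roughly like the sketch below. The `bf16xbf16_matmul` / `bf16xint16_matmul` wrappers are the ones the operator imports from its kernel.py (added in this stack, diff not expanded on this page); variant 2 is written the way it's intended to work once the follow-up PR lands, and the shape is just one of the Llama-2 shapes used by the benchmark.

```python
import torch

# Wrappers added in this stack (kernel.py); module path inferred from the operator's imports.
from torchbenchmark.operators.bf16xint16_gemm.kernel import (
    bf16xbf16_matmul,
    bf16xint16_matmul,
)

M, K, N = 16, 8192, 1280  # e.g. batch 16 with the attn.wqkv shape used by the benchmark
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randint(-(2**15), 2**15 - 1, (K, N), device="cuda", dtype=torch.int16)

# 1. baseline: weight pre-cast outside the timed region, kernel sees two bf16 operands
w_bf16 = w.to(torch.bfloat16)
c1 = bf16xbf16_matmul(x, w_bf16)

# 2. fused: int16 weight handed to the kernel, cast to bf16 inside it (follow-up PR)
c2 = bf16xint16_matmul(x, w)

# 3. "casted": materialize a bf16 copy of w in global memory as part of the timed work,
#    then run the same bf16xbf16 kernel on it
c3 = bf16xbf16_matmul(x, w.to(torch.bfloat16))
```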

Differential Revision: D59234085

imported-using-ghimport

D59234085

Test Plan: Imported from OSS

Reviewed By: xuzhao9

Pulled By: davidberard98

fbshipit-source-id: 75a493dbd78ee1aa1f63926f6dd61a2e7388816c
davidberard98 authored and facebook-github-bot committed Sep 11, 2024
1 parent 8351474 commit ebd00aa
Showing 3 changed files with 655 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/bf16xint16_gemm/__init__.py
@@ -0,0 +1 @@
from .bf16xint16_gemm import Operator
158 changes: 158 additions & 0 deletions torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -0,0 +1,158 @@
"""
Compute a bf16 (activation) x int16 (weight) gemm.
A stepping stone to a fast int4_gemm (another TritonBench kernel)
bf16xbf16 baseline implementation taken from the triton tutorial
https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
and the bf16xint16 implementation is a modified version of the same
tutorial kernel.
The benchmarking file (i.e. this file) is mostly copied from the
int4_gemm benchmarking file.
"""

import argparse
import os
import statistics

from typing import Any, List, Optional

import torch
import triton
import triton.language as tl

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .kernel import (
bf16xbf16_matmul,
bf16xbf16_matmul_kernel,
bf16xint16_matmul,
bf16xint16_matmul_kernel,
)


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]

def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args=tb_args, extra_args=extra_args)
# `Group size` and `inner K tiles` are defaults from gpt-fast.
self.group_size = 32
self.inner_k_tiles = 8

def get_input_iter(self):
def args(B, Dout, Din):
x = torch.randn(B, Din, device=self.device, dtype=torch.bfloat16)
w = torch.randint(
-(2**15),
2**15 - 1,
(Din, Dout),
device=self.device,
dtype=torch.int16,
)
return (x, w)

# LLama-2 shapes w/ 8-way tensor parallelism.
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
for bsz in (1, 4, 16, 64, 256, 1024, 2**12, 2**14, 2**16):
for name, (k, n) in name_to_shapes_70b.items():
yield args(bsz, n, k)

def get_x_val(self, example_inputs) -> float:
x, w = example_inputs
m, k = x.size()
_, n = w.size()
return (m, n, k)

@register_benchmark(baseline=True)
def bf16xbf16(self, x, w):
x = x.reshape(-1, x.size(-1))
w_bf16 = w.to(torch.bfloat16)
return lambda: bf16xbf16_matmul(x, w_bf16)

@register_benchmark()
def bf16xint16(self, x, w):
x = x.reshape(-1, x.size(-1))
# TODO(davidberard98) fix this to pass in an int16
w = w.to(torch.bfloat16)
return lambda: bf16xint16_matmul(x, w)

@register_benchmark()
def bf16xint16_casted(self, x, w):
x = x.reshape(-1, x.size(-1))
return lambda: bf16xbf16_matmul(x, w.to(torch.bfloat16))

@register_metric()
def best_config(self, fn, inputs, metrics):
if "bf16xbf16" in str(fn):
return str(bf16xbf16_matmul_kernel.best_config)
if "bf16xint16" in str(fn) and "casted" not in str(fn):
return str(bf16xint16_matmul_kernel.best_config)
return ""

@register_metric()
def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
def nbytes(t):
return t.numel() * t.element_size()

x, w = example_inputs
c = fn()

gb = (sum(nbytes(t) for t in (x, c)) + nbytes(w) // 8) / 1e9
return gb / metrics.latency * 1e3

@register_metric()
def tflops(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
a, b = example_inputs
m, k = a.size()
_, n = b.size()
flops = 2 * m * n * k
return flops / metrics.latency / 1e12 * 1e3

def plot(self):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=[
"B",
"m",
"n",
"k",
], # argument names to use as an x-axis for the plot
x_vals=self.output.x_vals, # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"torch",
"triton",
                ],  # possible values for `line_arg`
                line_names=[
                    "torch",
                    "triton",
                ],  # label name for the lines
                styles=[("blue", "-"), ("green", "-")],
                ylabel="tflops",  # label name for the y-axis
                plot_name="int4-gemm-performance",  # name for the plot. Used also as a file name for saving the plot.
                args={},  # values for function arguments not in `x_names` and `y_name`
            )
        )
        def _plot(B, m, n, k, provider):
            tflops = self.output.get_y_vals((B, m, n, k), provider, "tflops")
            return tflops

        save_path = "/tmp/bf16xint16_gemm"

        if not os.path.exists(save_path):
            os.mkdir(save_path)

        _plot.run(show_plots=True, print_data=True, save_path=save_path)
496 changes: 496 additions & 0 deletions torchbenchmark/operators/bf16xint16_gemm/kernel.py (diff not expanded on this page)
