Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tritonbench] Benchmark int4 gemm implementations #2261

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchbenchmark/operators/int4_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .int4_gemm import Operator
145 changes: 145 additions & 0 deletions torchbenchmark/operators/int4_gemm/int4_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
Compute a bf16 (activation) x int4 (weight) gemm.
Inspired by [gpt-fast](https://github.com/pytorch-labs/gpt-fast)
ATen kernels from tinygemm
Triton implementation by @jlebar: https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2
"""

import argparse
import os
import statistics
import torch
import triton.ops
import triton.language as tl

from typing import Any

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .kernel import pack_2xint4, matmul, matmul_kernel


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]

def __init__(self, mode, device, extra_args):
super().__init__(mode=mode, device=device, extra_args=extra_args)
# `Group size` and `inner K tiles` are defaults from gpt-fast.
self.group_size = 32
self.inner_k_tiles = 8

def get_input_iter(self):
def args(B, L, Dout, Din):
x = torch.randn(B, L, Din, device=self.device, dtype=torch.bfloat16)
w = torch.randint(-8, 7, (Din, Dout), device=self.device, dtype=torch.int32)
scales_and_zeros = torch.randn(
Din // self.group_size,
Dout,
2,
device=self.device,
dtype=torch.bfloat16,
)
return (x, w, scales_and_zeros)

# LLama-2 shapes w/ 8-way tensor parallelism.
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
for seq_len in (1, 4096):
for bsz in (1, 4, 16, 64):
for name, (k, n) in name_to_shapes_70b.items():
yield args(bsz, seq_len, n, k)

def get_x_val(self, example_inputs) -> float:
x, w, scales_and_zeros = example_inputs
B, m, k = x.size()
_, n = w.size()
return (B, m, n, k)

@register_benchmark(baseline=True)
def tinygemm(self, x, w, scales_and_zeros):
x = x.reshape(-1, x.size(-1))
w_int4 = torch.ops.aten._convert_weight_to_int4pack(
w.T.contiguous(), self.inner_k_tiles
)
return lambda: torch.ops.aten._weight_int4pack_mm(
x, w_int4, self.group_size, scales_and_zeros
)

@register_benchmark()
def triton(self, x, w, scales_and_zeros):
x = x.reshape(-1, x.size(-1))
w_int4 = pack_2xint4(w).T.contiguous().T
return lambda: matmul(x, w_int4)

@register_metric()
def best_config(self, fn, inputs, metrics):
if "triton" in str(fn):
return str(matmul_kernel.best_config)
return ""

@register_metric()
def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
def nbytes(t):
return t.numel() * t.element_size()

x, w, scale_and_zero = example_inputs
c = fn()

gb = (sum(nbytes(t) for t in (x, scale_and_zero, c)) + nbytes(w) // 8) / 1e9
return list(map(lambda ms: gb / ms * 1e3, metrics.latency))

@register_metric()
def tflops(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
a, b, _ = example_inputs
B, m, k = a.size()
m = B * m
_, n = b.size()
flops = 2 * m * n * k
return [flops / x / 1e12 * 1e3 for x in metrics.latency]

def plot(self):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=[
"B",
"m",
"n",
"k",
], # argument names to use as an x-axis for the plot
x_vals=self.output.x_vals, # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"tinygemm",
"triton",
], # possible values for `line_arg``
line_names=[
"tinygemm",
"triton",
], # label name for the lines
styles=[("blue", "-"), ("green", "-")],
ylabel="tflops", # label name for the y-axis
plot_name="int4-gemm-performance", # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def _plot(B, m, n, k, provider):
tflops = self.output.get_y_vals((B, m, n, k), provider, "tflops")
return tflops

save_path = "/tmp/int4_gemm"

if not os.path.exists(save_path):
os.mkdir(save_path)

_plot.run(show_plots=True, print_data=True, save_path=save_path)
166 changes: 166 additions & 0 deletions torchbenchmark/operators/int4_gemm/kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
Triton implementation by @jlebar: https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2
"""

import torch
import triton
import triton.language as tl

AUTOTUNE_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
},
num_stages=4,
num_warps=8,
),
]


@triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Matrix dimensions.
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
#
# We assume `b` is packed with 2 `int4` elements per K, i.e. it's a
# (K//2)xNx(2xint4) matrix, represented in Triton as (K//2)xNxi8. If K
# is the minor dimension, then stride_bk should logically be 0.5. But
# we don't want a fractional stride! So let the given stride be the
# stride per 2xint4.
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
tl.device_assert(K % BLOCK_SIZE_K == 0)

# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K // 2, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_ak = tl.arange(0, BLOCK_SIZE_K)
offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_ak[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs)
tl.static_assert(b.dtype == tl.int8)

# Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
# _4_i8 because the literal "4" is considered an i32, which causes the
# shift operands to be widened to i32.
_4_i8 = tl.full((1,), 4, dtype=tl.int8)
b_lo = (b << _4_i8) >> _4_i8
b_hi = b >> _4_i8
# Workaround: Convert before the join() so that Triton can load the data
# after the join using ldmatrix.
b_f16 = (
tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
.permute(0, 2, 1)
.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
)

accumulator += tl.dot(a, b_f16)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk // 2

c = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
assert a.shape[1] == b.shape[0] * 2, "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
_, N = b.shape

c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
matmul_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
)
return c


def pack_2xint4(t):
# Packs a KxNxfp16 matrix into a (K//2)xNx(2xint4) matrix.
t = t.to(torch.int8).reshape(t.shape[0] // 2, 2, t.shape[1]).permute(1, 0, 2)
return (t[0] & 0xF) | (t[1] << 4)
Loading