bf16xint16_gemm operator: add --transpose option (#2466)
Summary:
`--transpose` will make this benchmark test an int16 x bf16 matmul instead of a bf16 x int16 one.

This matters for H100 because the wgmma instruction can take registers only on the LHS, so int16 x bf16 (with the int16 operand on the LHS) is probably the easier layout to support efficiently.
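A minimal sketch (not part of the commit) of the identity the option relies on; dtypes are simplified to fp32 here, whereas the benchmark pairs a bf16 `x` with an int16 `w` that is cast to bf16 inside the kernel:

```python
# Illustration of the --transpose identity: (A @ B) == (B.T @ A.T).T.
import torch

x = torch.randn(64, 128)   # stands in for the bf16 activation
w = torch.randn(128, 32)   # stands in for the int16 weight

ab = x @ w                                 # default: bf16 x int16
ba = w.T.contiguous() @ x.T.contiguous()   # --transpose: int16 x bf16

torch.testing.assert_close(ab, ba.T)       # same result, int16 moved to the LHS
```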

Pull Request resolved: #2466

Test Plan:
In OSS: ran `python run_benchmark.py triton --op bf16xint16_gemm --transpose`

Internally, ran `buck2 run mode/opt //pytorch/benchmark:triton -- --op bf16xint16_gemm --transpose`

Internally, we run into the issue fixed by triton-lang/triton#4695, but otherwise both commands run.

Reviewed By: aakhundov

Differential Revision: D63294109

Pulled By: davidberard98

fbshipit-source-id: 3ea05bb09e62f51c405ae538726caf80e1ba0d63
davidberard98 authored and facebook-github-bot committed Sep 24, 2024
1 parent 6a089a4 commit 0ab0e47
Showing 2 changed files with 52 additions and 8 deletions.
42 changes: 37 additions & 5 deletions torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -41,6 +41,22 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args=tb_args, extra_args=extra_args)

parser = argparse.ArgumentParser()
parser.add_argument(
"--transpose",
action="store_true",
help="Instead of computing A @ B, compute B.T @ A.T.",
)

gemm_args = parser.parse_args(extra_args)

# Normally, we compute x @ w, where x is bf16 and w is int16.
# If transposed, we compute w.T @ x.T, where x.T/w.T are contiguous and w.T is int16.
# Motivation: on H100, only the lhs of the wgmma instruction can be read from registers,
# so the order (which input is int16) matters.
self.transpose = gemm_args.transpose

# `Group size` and `inner K tiles` are defaults from gpt-fast.
self.group_size = 32
self.inner_k_tiles = 8
@@ -55,6 +71,12 @@ def args(B, Dout, Din):
device=self.device,
dtype=torch.int16,
)

if self.transpose:
# transpose logic below is only valid for 2D tensors.
assert x.dim() == 2
assert w.dim() == 2
return (w.T.contiguous(), x.T.contiguous())
return (x, w)

# LLama-2 shapes w/ 8-way tensor parallelism.
@@ -64,6 +86,10 @@ def args(B, Dout, Din):
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}

yield args(2**16, 1280, 8192)
return

for bsz in (1, 4, 16, 64, 256, 1024, 2**12, 2**14, 2**16):
for name, (k, n) in name_to_shapes_70b.items():
yield args(bsz, n, k)
@@ -76,19 +102,25 @@ def get_x_val(self, example_inputs) -> float:

@register_benchmark(baseline=True)
def bf16xbf16(self, x, w):
x = x.reshape(-1, x.size(-1))
w_bf16 = w.to(torch.bfloat16)
return lambda: bf16xbf16_matmul(x, w_bf16)
x_bf16 = x.to(torch.bfloat16) if self.transpose else x
w_bf16 = w if self.transpose else w.to(torch.bfloat16)
return lambda: bf16xbf16_matmul(x_bf16, w_bf16)

@register_benchmark()
def bf16xint16(self, x, w):
x = x.reshape(-1, x.size(-1))
return lambda: bf16xint16_matmul(x, w)
return lambda: bf16xint16_matmul(x, w, transpose=self.transpose)

@register_benchmark()
def bf16xint16_casted(self, x, w):
x = x.reshape(-1, x.size(-1))
return lambda: bf16xbf16_matmul(x, w.to(torch.bfloat16))

def fn():
x_bf16 = x.to(torch.bfloat16) if self.transpose else x
w_bf16 = w if self.transpose else w.to(torch.bfloat16)
return bf16xbf16_matmul(x_bf16, w_bf16)

return fn

@register_metric()
def best_config(self, fn, inputs, metrics):
18 changes: 15 additions & 3 deletions torchbenchmark/operators/bf16xint16_gemm/kernel.py
@@ -379,6 +379,7 @@ def bf16xint16_matmul_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
TRANSPOSE: tl.constexpr, # if true, assume a_ptr is int16; otherwise assume b_ptr is int16
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -421,9 +422,19 @@ def bf16xint16_matmul_kernel(
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
b_bf16 = b.to(tl.bfloat16)
if TRANSPOSE:
tl.static_assert(a.dtype == tl.int16)
tl.static_assert(b.dtype == tl.bfloat16)
a_bf16 = a.to(tl.bfloat16)
b_bf16 = b
else:
tl.static_assert(a.dtype == tl.bfloat16)
tl.static_assert(b.dtype == tl.int16)
a_bf16 = a
b_bf16 = b.to(tl.bfloat16)

# We accumulate along the K dimension.
accumulator = tl.dot(a, b_bf16, accumulator)
accumulator = tl.dot(a_bf16, b_bf16, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -469,7 +480,7 @@ def bf16xbf16_matmul(a, b):
return c


def bf16xint16_matmul(a, b):
def bf16xint16_matmul(a, b, transpose=False):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -494,5 +505,6 @@ def bf16xint16_matmul(a, b):
b.stride(1), #
c.stride(0),
c.stride(1), #
TRANSPOSE=transpose,
)
return c
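
A hedged usage sketch of the updated entry point (not from the commit): the import path is assumed to match the file above, and a CUDA device with Triton is required. With `transpose=True`, the int16 operand goes on the LHS, matching the dtypes the kernel's static asserts expect.

```python
# Hypothetical usage sketch; import path, shapes, and tolerances are assumptions.
import torch
from torchbenchmark.operators.bf16xint16_gemm.kernel import (
    bf16xbf16_matmul,
    bf16xint16_matmul,
)

M, K, N = 256, 512, 128
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)           # bf16 activation
w = torch.randint(-8, 8, (K, N), device="cuda", dtype=torch.int16)   # int16 weight

# Default layout: bf16 LHS, int16 RHS.
out = bf16xint16_matmul(x, w)

# --transpose layout: int16 LHS, bf16 RHS (what TRANSPOSE=True statically asserts).
out_t = bf16xint16_matmul(w.T.contiguous(), x.T.contiguous(), transpose=True)

# Both should agree with the bf16 x bf16 baseline up to bf16 rounding.
ref = bf16xbf16_matmul(x, w.to(torch.bfloat16))
torch.testing.assert_close(out, ref)
torch.testing.assert_close(out_t.T, ref)
```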
