Add INT4-FP8 rowwise matmul tests (pytorch#2622)
Summary:
Pull Request resolved: pytorch#2622

- Add INT4-FP8 rowwise matmul tests covering both CUDA graph and eager execution (see the sketch after this list)
- Switch from number of groups to group size to align with the E2E variable
- Fix a minor dtype issue
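
At a high level, the eager path exercised by the new test quantizes the activation to FP8 row-wise, quantizes and packs the weight to INT4 row-wise, and compares the fused kernel against a bf16 matmul with loose tolerances. A minimal sketch of that path, using the ops and helpers from the diff below with illustrative sizes (the test itself draws sizes from hypothesis strategies):

import torch
# int4_row_quantize and pack_int4 are the test helpers added in the diff below.

x = torch.randn(2048, 128, dtype=torch.bfloat16, device="cuda") * 0.1  # activation
w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda") * 0.01  # weight

xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)  # FP8 row-wise activation
wq, w_scale, w_zp = int4_row_quantize(w, 128)            # INT4 row-wise weight, group size 128
wq = pack_int4(wq)                                        # two int4 values per int8

zq = torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)
zq_ref = (x @ w.T).to(torch.bfloat16)                     # loose-tolerance reference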

Reviewed By: jwfromm

Differential Revision: D57677844

fbshipit-source-id: 78578490025d72a5d73dd207fce4518304283de0
jiawenliu64 authored and facebook-github-bot committed May 24, 2024
1 parent 4285088 commit 63ca6dc
Showing 1 changed file with 77 additions and 1 deletion.
fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -48,6 +48,42 @@ def fp8_col_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return xq, scale.reciprocal().to(torch.float32)


def int4_row_quantize(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    N = x.shape[0]
    K = x.shape[1]
    assert K >= group_size and K % group_size == 0
    num_groups = K // group_size
    x = x.view(N, group_size, num_groups).to(torch.float)
    # Zero point should be values in each row closest to zero.
    zp_ind = torch.argmin(torch.abs(x), dim=1, keepdim=True)
    zero_point = torch.gather(x, 1, zp_ind)  # + (.5 / scale)
    x_centered = x - zero_point
    scale = 7.5 / torch.max(torch.abs(x_centered), dim=1, keepdim=True).values
    xq = torch.clamp(torch.round((x_centered * scale) - 0.5), min=-8, max=7)
    return (
        xq.to(torch.int8).view(N, -1).contiguous(),
        scale.reciprocal().view(N, -1).contiguous(),  # pyre-ignore
        (zero_point + (0.5 / scale)).view(N, -1).contiguous(),
    )
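
Because the helper returns the reciprocal scale and a zero point pre-offset by half a quantization step, the quantized output dequantizes as x ≈ xq * scale + zero_point with the returned per-group values. The helper below (int4_row_dequantize_ref is a hypothetical name, not part of this diff) is a minimal sketch of that inverse, assuming the same default group size of 128:

def int4_row_dequantize_ref(
    xq: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    group_size: int = 128,
) -> torch.Tensor:
    # Hypothetical illustration only: invert int4_row_quantize above.
    N, K = xq.shape
    num_groups = K // group_size
    # Mirror the grouped view used during quantization.
    xq = xq.view(N, group_size, num_groups).to(torch.float)
    scale = scale.view(N, 1, num_groups)
    zero_point = zero_point.view(N, 1, num_groups)
    # x was quantized as round((x - zp) / scale - 0.5), so undo it per group.
    return (xq * scale + zero_point).view(N, K)

Round-tripping a weight through int4_row_quantize and this helper should recover it up to 4-bit quantization error.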


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # Given int8 x, pack adjacent int4 values into a single int8.
    low_x = x[:, ::2]
    high_x = x[:, 1::2]

    # High bits need to be left shifted; this also masks off extra bits.
    high_x = torch.bitwise_left_shift(high_x, 4)
    # Low bits need to have sign bits removed.
    low_x = torch.bitwise_and(low_x, 0xF)

    # Recombine into a single value with bitwise or.
    return torch.bitwise_or(low_x, high_x).contiguous()
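
For completeness, packing can be reversed by splitting each int8 back into its two nibbles and sign-extending them to the int4 range; unpack_int4 below is a hypothetical sketch for illustration and is not part of this diff:

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of pack_int4, for illustration only.
    low_x = torch.bitwise_and(packed, 0xF)
    high_x = torch.bitwise_and(torch.bitwise_right_shift(packed, 4), 0xF)
    # Sign-extend each nibble back to the int4 range [-8, 7].
    low_x = torch.where(low_x >= 8, low_x - 16, low_x)
    high_x = torch.where(high_x >= 8, high_x - 16, high_x)
    # Re-interleave columns so unpack_int4(pack_int4(x)) recovers x for int4-range inputs.
    out = torch.empty(
        packed.shape[0], packed.shape[1] * 2, dtype=torch.int8, device=packed.device
    )
    out[:, ::2] = low_x
    out[:, 1::2] = high_x
    return out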


@unittest.skipIf(
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
@@ -121,7 +157,7 @@ def test_f8f8bf16_rowwise_simple(self) -> None:
        x = xq.bfloat16()
        w = wq.bfloat16()

-        zq_ref = (x @ w.T).to(torch.bfloat16) * x_scale[:, None] * w_scale[None, :]
+        zq_ref = (x @ w.T * x_scale[:, None] * w_scale[None, :]).to(torch.bfloat16)

        torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3)

@@ -216,6 +252,46 @@ def test_quantize_fp8_matmul(
            zq_ref += bias
        torch.testing.assert_close(zq, zq_ref, atol=8.0e-2, rtol=8.0e-2)

    @unittest.skipIf(
        not torch.version.cuda, "Skip on AMD: built-in quantize ops not yet supported."
    )
    @settings(deadline=None)
    @given(
        B_T=st.sampled_from([2048, 4096]),
        D=st.sampled_from([128, 256]),
        HD_L=st.sampled_from([256, 512]),
        QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]),
        CudaGraph=st.sampled_from([True, False]),
    )
    def test_quantize_int4_fp8_matmul(
        self,
        B_T: int,
        D: int,
        HD_L: int,
        QType: torch.dtype,
        CudaGraph: bool,
    ) -> None:
        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
        w = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01

        wq, w_scale, w_zp = int4_row_quantize(w, 128)
        wq = pack_int4(wq).contiguous().to(device="cuda")
        w_scale = w_scale.contiguous().to(device="cuda")
        w_zp = w_zp.contiguous().to(device="cuda")

        if CudaGraph:
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
                zq = torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)
            g.replay()
        else:
            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
            zq = torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)

        zq_ref = (x @ w.T).to(torch.bfloat16)
        torch.testing.assert_close(zq, zq_ref, atol=8.0e-2, rtol=8.0e-2)

    @unittest.skipIf(
        not torch.version.cuda, "Skip on AMD: built-in quantize ops not yet supported."
    )
