From 8e335d3940319cf1a722a35b00513ffba2133669 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 21 May 2024 20:24:16 -0700 Subject: [PATCH] Add AMD Rowwise FP8 Matmul (#2611) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2611 This diff extends the `fp8fp8bf16_rowwise` gemm operation to AMD through a new CK kernel. The new kernel requires new stride support that is only available in developer branches of CK, so we must rely on the `ai_codesign/gen_ai` CK repo. I also extend the fp8 benchmarking suite to include rowwise measurements. I'll soon add detailed benchmarking results but the quick summary is that performance looks quite good, typically inline with tensorwise quantization and sometimes faster, presumably due to using the latest and greatest CK pipelines. Full performance results can be found on the rowwise sheet linked below. Overall, the rowwise kernel is similarly performant to tensorwise scaling and actually does quite a bit better in some cases. https://docs.google.com/spreadsheets/d/1hK62EJIt9mZQJdZBxf5rGZ3u0hjMng7AkpwKnIjyn3k/edit#gid=881561326 Reviewed By: jianyuh, jiawenliu64 Differential Revision: D57600068 fbshipit-source-id: c5c533f1d332de293f788792816e535ab1131cec --- .../experimental/gen_ai/bench/ck_fp8_bench.py | 107 +++++++-- .../gen_ai/src/quantize/ck_extensions.hip | 223 +++++++++++++++++- .../gen_ai/src/quantize/cutlass_extensions.cu | 5 - .../gen_ai/src/quantize/quantize.cpp | 9 +- .../gen_ai/test/quantize/quantize_test.py | 39 ++- 5 files changed, 334 insertions(+), 49 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py index 66b75b823..96b058685 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py @@ -31,6 +31,20 @@ def set_amd_env_vars() -> None: os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30" +@torch.no_grad() +def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Quantize an input tensor and return the fp8 tensor and its inverse scale. + x_row_max = torch.max(torch.abs(x), dim=1).values + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + scale = E4M3_MAX_POS / torch.clamp(x_row_max, EPS) + # pyre-fixme[16]: Item `float` of `typing.Union[float, torch._tensor.Tensor]` has no attribute `__getitem__`. + xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to( + torch.float8_e4m3fnuz + ) + # pyre-fixme[16]: Item `float` of `typing.Union[float, torch._tensor.Tensor]` has no attribute `__getitem__`. + return xq, scale.to(torch.float32).reciprocal() + + def fp8_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: f8_max = torch.tensor(E4M3_MAX_POS, device=x.device) x_amax = torch.max(torch.abs(x)) @@ -65,12 +79,22 @@ def forward( return output -class CKMatmul(torch.nn.Module): +class CKTensorMatmul(torch.nn.Module): def forward( self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor ) -> torch.Tensor: - out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale) - return out + return torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale) + + +class CKRowMatmul(torch.nn.Module): + def forward( + self, + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.fbgemm.f8f8bf16_rowwise(a, b, a_scale, b_scale) @torch.no_grad() @@ -82,13 +106,18 @@ def evaluate_impl( baseline_func: Callable[ [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor ], - ck_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], -) -> Tuple[float, float, float, float, float]: + ck_tensor_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], + ck_row_func: Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor + ], +) -> Tuple[float, float, float, float, float, float, float]: print(f"Evaluating {M=}, {N=}, {K=}") A = torch.randn(M, K).to(dtype=torch.bfloat16, device="cuda") QA, a_scale = fp8_quantize(A) B = torch.randn(N, K).to(dtype=torch.bfloat16, device="cuda") QB, b_scale = fp8_quantize(B) + QA_row, a_scale_row = fp8_row_quantize(A) + QB_row, b_scale_row = fp8_row_quantize(B) # Check accuracy. out_ref = fp_func(A.to(torch.float32), B.t().to(torch.float32)) @@ -97,9 +126,13 @@ def evaluate_impl( baseline_sim = torch.mean(torch.pow(torch.abs(baseline_out - out_ref), 2)) print(f"Baseline accuracy: {baseline_sim}") - ck_out = ck_func(QA, QB, a_scale * b_scale) - ck_sim = torch.mean(torch.pow(torch.abs(ck_out - out_ref), 2)) - print(f"CK accuracy: {ck_sim}") + ck_tensor_out = ck_tensor_func(QA, QB, a_scale * b_scale) + ck_tensor_sim = torch.mean(torch.pow(torch.abs(ck_tensor_out - out_ref), 2)) + print(f"CK tensorwise accuracy: {ck_tensor_sim}") + + ck_row_out = ck_row_func(QA_row, QB_row, a_scale_row, b_scale_row) + ck_row_sim = torch.mean(torch.pow(torch.abs(ck_row_out - out_ref), 2)) + print(f"CK rowwise accuracy: {ck_row_sim}") # Benchmark runtimes. ms_ref: float = triton.testing.do_bench(lambda: fp_func(A, B.t())) @@ -110,10 +143,25 @@ def evaluate_impl( ) print(f"Baseline runtime: {ms_baseline} ms") - ms_ck: float = triton.testing.do_bench(lambda: ck_func(QA, QB, a_scale * b_scale)) - print(f"CK runtime: {ms_ck} ms") + ms_tensor_ck: float = triton.testing.do_bench( + lambda: ck_tensor_func(QA, QB, a_scale * b_scale) + ) + print(f"CK tensorwise runtime: {ms_tensor_ck} ms") - return float(baseline_sim.item()), float(ck_sim.item()), ms_baseline, ms_ck, ms_ref + ms_row_ck: float = triton.testing.do_bench( + lambda: ck_row_func(QA_row, QB_row, a_scale_row, b_scale_row) + ) + print(f"CK rowwise runtime: {ms_row_ck} ms") + + return ( + float(baseline_sim.item()), + float(ck_tensor_sim.item()), + float(ck_row_sim.item()), + ms_baseline, + ms_tensor_ck, + ms_row_ck, + ms_ref, + ) def main(args: Any) -> None: @@ -121,12 +169,19 @@ def main(args: Any) -> None: set_amd_env_vars() with torch.no_grad(): - ck_mod = CKMatmul() + ck_tensor_mod = CKTensorMatmul() + ck_row_mod = CKRowMatmul() baseline_mod = BaselineMatmul() bf16_mod = FPMatMul() if args.torch_compile_mode: - ck_mod = torch.compile( - ck_mod, + ck_tensor_mod = torch.compile( + ck_tensor_mod, + dynamic=False, + backend="inductor", + mode=args.torch_compile_mode, + ) + ck_row_mod = torch.compile( + ck_row_mod, dynamic=False, backend="inductor", mode=args.torch_compile_mode, @@ -147,15 +202,23 @@ def main(args: Any) -> None: benchmark_results = [] # Test over a bunch of shapes. - M = [13312, 16384, 16032, 2304, 2048] - N = [4096, 2304, 13312, 8192] - K = [16384, 6656, 2304, 2048, 13312] + M = [128, 2048, 2304, 13312, 16032, 16384] + N = [128, 2304, 4096, 8192, 13312] + K = [128, 2048, 2304, 6656, 13312, 16384] for m in M: for n in N: for k in K: - baseline_sim, ck_sim, ms_baseline, ms_ck, ms_bf16 = evaluate_impl( - m, n, k, bf16_mod, baseline_mod, ck_mod + ( + baseline_sim, + ck_tensor_sim, + ck_row_sim, + ms_baseline, + ms_tensor_ck, + ms_row_ck, + ms_bf16, + ) = evaluate_impl( + m, n, k, bf16_mod, baseline_mod, ck_tensor_mod, ck_row_mod ) benchmark_results.append( { @@ -163,9 +226,11 @@ def main(args: Any) -> None: "N": n, "K": k, "baseline_sim": baseline_sim, - "ck_sim": ck_sim, + "ck_tensor_sim": ck_tensor_sim, + "ck_row_sim": ck_row_sim, "ms_baseline": ms_baseline, - "ms_ck": ms_ck, + "ms_tensor_ck": ms_tensor_ck, + "ms_row_ck": ms_row_ck, "ms_bf16": ms_bf16, } ) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions.hip index 17c18e9cc..e94c5ec9d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions.hip @@ -21,6 +21,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" @@ -32,6 +33,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" // Define commonly used types. template @@ -51,6 +53,24 @@ struct Scale { float scale_; }; +struct RowwiseScale { + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void + operator()( + ck::bhalf_t& e, + const float& c, + const float& d0, + const float& d1) const { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } +}; + namespace fbgemm_gpu { template < @@ -61,17 +81,13 @@ template < int MPER_WAVE, int NPER_WAVE, bool PADDING = false> -at::Tensor f8f8bf16_tensorwise_impl(at::Tensor XQ, at::Tensor WQ, double scale) { +at::Tensor +f8f8bf16_tensorwise_impl(at::Tensor XQ, at::Tensor WQ, double scale) { // Get input information. int M = XQ.size(0); int N = WQ.size(0); int K = XQ.size(1); - // Check that sizes are sufficiently large for grid dispatch. - TORCH_CHECK( - M >= 128 && N >= 128 && K >= 256, - "Minimum supported M,N,K is 128,128,256."); - int StrideA = K; int StrideB = K; int StrideC = N; @@ -211,17 +227,28 @@ std::tuple get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { } } -at::Tensor -f8f8bf16_tensorwise(at::Tensor XQ, at::Tensor WQ, double scale, bool use_fast_accum) { +at::Tensor f8f8bf16_tensorwise( + at::Tensor XQ, + at::Tensor WQ, + double scale, + bool use_fast_accum) { + // Check that input types are compatible with AMD FP8. + TORCH_CHECK( + (XQ.dtype() == at::kFloat8_e4m3fnuz) && + (WQ.dtype() == at::kFloat8_e4m3fnuz), + "Inputs must be type float8_e4m3fnuz."); TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum"); auto [kernel, pad] = get_kernel_mode(XQ, WQ); if (pad) { if (kernel == KernelMode::Small) { - return f8f8bf16_tensorwise_impl<64, 32, 64, 64, 1, 2, true>(XQ, WQ, scale); + return f8f8bf16_tensorwise_impl<64, 32, 64, 64, 1, 2, true>( + XQ, WQ, scale); } else if (kernel == KernelMode::Large) { - return f8f8bf16_tensorwise_impl<256, 256, 128, 64, 4, 2, true>(XQ, WQ, scale); + return f8f8bf16_tensorwise_impl<256, 256, 128, 64, 4, 2, true>( + XQ, WQ, scale); } else { - return f8f8bf16_tensorwise_impl<256, 128, 128, 64, 2, 2, true>(XQ, WQ, scale); + return f8f8bf16_tensorwise_impl<256, 128, 128, 64, 2, 2, true>( + XQ, WQ, scale); } } else { if (kernel == KernelMode::Small) { @@ -234,6 +261,180 @@ f8f8bf16_tensorwise(at::Tensor XQ, at::Tensor WQ, double scale, bool use_fast_ac } } +template < + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int MPER_WAVE, + int NPER_WAVE, + bool PADDING = false> +at::Tensor f8f8bf16_rowwise_impl( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale) { + // Get input information. + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + int StrideA = K; + int StrideB = K; + int StrideE = N; + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ADataType = ck::f8_t; + using BDataType = ck::f8_t; + using D0DataType = float; + using D1DataType = float; + using DsDataType = ck::Tuple; + using EDataType = ck::bhalf_t; + using AccDataType = float; + using CShuffleDataType = float; + + using ALayout = Row; + using BLayout = Col; + using D0Layout = Row; + using D1Layout = Col; + using DsLayout = ck::Tuple; + using ELayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = RowwiseScale; + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault; + using ComputeType = ck::f8_t; + + // Define derivative constants based on template parameters. + static constexpr int BLOCK_CLUSTER = BLOCK_SIZE / 4; + static constexpr int CBLOCK_N = NBLOCK / 16; + static constexpr int CBLOCK_M = BLOCK_SIZE / CBLOCK_N; + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + BLOCK_SIZE, // Block Size + MBLOCK, // M per Block + NBLOCK, // N per Block + KBLOCK, // K per Block + 16, // AK1 + 16, // BK1 + 32, // M per Xdl + 32, // N per Xdl + MPER_WAVE, // Mxdl per Wave + NPER_WAVE, // Nxdl per Wave + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 0, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 0, + 1, + 1, + S<1, CBLOCK_M, 1, CBLOCK_N>, + S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ComputeType>; + + // Create gemm launcher and arguments. + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto I0 = + ck::Number<0>{}; // Used to indicate 0 stride for row and col broadcast. + + auto argument = gemm.MakeArgument( + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + std::array{ + reinterpret_cast(w_scale.data_ptr()), + reinterpret_cast(x_scale.data_ptr())}, + reinterpret_cast(Y.data_ptr()), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); + + return Y; +} + +at::Tensor f8f8bf16_rowwise( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + c10::optional bias, + bool use_fast_accum) { + // Check that input datatypes are valid. + TORCH_CHECK( + (XQ.dtype() == at::kFloat8_e4m3fnuz) && + (WQ.dtype() == at::kFloat8_e4m3fnuz), + "Inputs must be type float8_e4m3fnuz."); + TORCH_CHECK((x_scale.dtype() == at::kFloat) && (w_scale.dtype() == at::kFloat), "Scales must be float32."); + TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum."); + TORCH_CHECK(!(bias.has_value()), "AMD does not yet support bias."); + auto [kernel, pad] = get_kernel_mode(XQ, WQ); + if (pad) { + if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_impl<256, 256, 128, 64, 4, 2, true>(XQ, WQ, x_scale, w_scale); + } else { + return f8f8bf16_rowwise_impl<256, 128, 128, 64, 2, 2, true>(XQ, WQ, x_scale, w_scale); + } + } else { + if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_impl<256, 256, 128, 64, 4, 2, false>(XQ, WQ, x_scale, w_scale); + } else { + return f8f8bf16_rowwise_impl<256, 128, 128, 64, 2, 2, false>(XQ, WQ, x_scale, w_scale); + } + } +} + } // namespace fbgemm_gpu #endif // defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu index ef44aedb1..dd76be16e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu @@ -1326,11 +1326,6 @@ at::Tensor f8f8bf16_rowwise( bias.value().dtype() == at::kBFloat16, "Bias type must be bfloat16 or float32 if provided."); } - // Extract problem size. - auto M = XQ.size(0); - auto K = XQ.size(1); - auto N = WQ.size(0); - bool use_bias = bias.has_value(); bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 7c1d80b3b..91abab3a3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -105,9 +105,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor"); - m.def( - "f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor"); - m.def( "f8f8bf16_cublas(Tensor A, Tensor B, Tensor Ainvs, Tensor Binvs, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor"); @@ -120,6 +117,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic); #endif + m.def( + "f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor"); m.def( "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor"); m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor"); @@ -164,9 +163,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise); + m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise); #ifndef USE_ROCM m.impl("i8i8bf16", i8i8bf16); - m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); @@ -255,9 +254,9 @@ at::Tensor f8i4bf16_rowwise_meta( TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta); + m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta); #ifndef USE_ROCM m.impl("i8i8bf16", i8i8bf16_meta); - m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta); m.impl("f8f8bf16", f8f8bf16_meta); m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta); diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 38b5e158a..d34a78d41 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -36,7 +36,7 @@ def fp8_row_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x_row_max = torch.max(torch.abs(x), dim=1).values max_scaling_factor = E4M3_MAX_POS * 512.0 # Match kernel logics scale = torch.Tensor(E4M3_MAX_POS / x_row_max).clamp(max=max_scaling_factor) - xq = (x * scale.unsqueeze(1)).to(torch.float8_e4m3fn) + xq = (x * scale.unsqueeze(1)).to(fp8_e4m3) return xq, scale.reciprocal().to(torch.float32) @@ -45,7 +45,7 @@ def fp8_col_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x_col_max = torch.max(torch.abs(x), dim=0).values max_scaling_factor = E4M3_MAX_POS * 512.0 # Match kernel logics scale = torch.Tensor(E4M3_MAX_POS / x_col_max).clamp(max=max_scaling_factor) - xq = (x * scale.unsqueeze(0)).to(torch.float8_e4m3fn) + xq = (x * scale.unsqueeze(0)).to(fp8_e4m3) return xq, scale.reciprocal().to(torch.float32) @@ -58,8 +58,8 @@ class FP8Tests(unittest.TestCase): def test_fp8_python(self) -> None: src_float = torch.randn(1000, 1000).cuda() src_float[0, 0] = 1e6 - fp8_152 = src_float.to(torch.float8_e5m2) - fp8_143 = src_float.to(torch.float8_e4m3fn) + fp8_152 = src_float.to(fp8_e5m2) + fp8_143 = src_float.to(fp8_e4m3) assert len(fp8_152.float().unique()) <= 256 assert len(fp8_143.float().unique()) <= 256 @@ -102,6 +102,31 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None: torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) + @unittest.skipIf( + ((not torch.version.cuda) and (not torch.version.hip)), + "Skip if no GPU is present.", + ) + @settings(deadline=None) + def test_f8f8bf16_rowwise_simple(self) -> None: + M = 128 + N = 128 + K = 256 + x = torch.randn(size=(M, K), dtype=torch.bfloat16, device="cuda") * 0.1 + w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01 + + xq, x_scale = fp8_row_quantize_ref(x) + wq, w_scale = fp8_row_quantize_ref(w) + + zq = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale) + + # Fake quant + x = xq.bfloat16() + w = wq.bfloat16() + + zq_ref = (x @ w.T).to(torch.bfloat16) * x_scale[:, None] * w_scale[None, :] + + torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) + @unittest.skipIf( not torch.version.cuda, "Skip on AMD: built in quantize ops not yet suported." ) @@ -111,7 +136,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None: D=st.sampled_from([128, 256]), HD_L=st.sampled_from([256, 512]), Mode=st.sampled_from(["tensorwise", "tensorwise_broadcast", "rowwise"]), - QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]), + QType=st.sampled_from([fp8_e4m3, fp8_e5m2]), Bias=st.sampled_from([True, False]), CudaGraph=st.sampled_from([True, False]), ) @@ -204,7 +229,7 @@ def test_quantize_fp8_matmul( ) def test_quantize_fp8_per_tensor_row_col(self, B_T: int, D: int, Mode: str) -> None: x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1 - fp8_max = torch.finfo(torch.float8_e4m3fn).max + fp8_max = torch.finfo(fp8_e4m3).max if Mode == "tensorwise": xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x) @@ -212,7 +237,7 @@ def test_quantize_fp8_per_tensor_row_col(self, B_T: int, D: int, Mode: str) -> N xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x) x_max = x.abs().max() x_scale_ref = (x_max / fp8_max).float() - xq_ref = (x * fp8_max / x_max).to(torch.float8_e4m3fn) + xq_ref = (x * fp8_max / x_max).to(fp8_e4m3) elif Mode == "rowwise": xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x) x = (xq.float() / x_scale.unsqueeze(1)).bfloat16() # Fake quantization