From e3d6b1ed37f47508f40f94bfbd337b7c6695c1ec Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Tue, 1 Oct 2024 17:35:14 -0700 Subject: [PATCH] MoE BMM FP8 rowwise Summary: Enable MoE BMM FP8 rowwise: - MoE BMM FP8 rowwise achieves **up to 4.5x (2.1x on average) speedup compared to BF16 BMM** - In E2E with MoE 16b x 16, FP8 with BMM achieves **1.2x speedup than BF16** - Integrated in E2E and verified correctness which matches BF16 generations - More results are in the [data sheet](https://docs.google.com/spreadsheets/d/1OLdz4MlzWS9pdgTBq4Jjy0-9_nPn-NmdrMolY0jZOXE/edit?gid=0#gid=0) {F1903027122} Differential Revision: D63681109 --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 1 + .../f8f8bf16_rowwise_batched.cu | 502 ++++++++++++++++++ .../cutlass_extensions/include/kernel_mode.h | 21 + .../gen_ai/src/quantize/quantize.cpp | 27 + .../gen_ai/test/quantize/quantize_test.py | 10 +- 5 files changed, 560 insertions(+), 1 deletion(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index a2916c9e5..5accb9c53 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -43,6 +43,7 @@ else() src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu src/quantize/cutlass_extensions/f8f8bf16_cublas.cu src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu + src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu src/quantize/cutlass_extensions/i8i8bf16.cu src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu new file mode 100644 index 000000000..21c68257d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu @@ -0,0 +1,502 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +// Cutlass rowwise batched kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +at::Tensor f8f8bf16_rowwise_batched_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + // XQ: B x M x K + // WQ: B x N x K + // output: B x M x N + int B = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + int K = WQ.size(2); + TORCH_CHECK(XQ.size(-1) == K); + auto out_sizes = XQ.sizes().vec(); + out_sizes.back() = N; + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + at::Tensor Y; + if (output.has_value()) { + Y = output.value(); + // Make sure the provided output has the proper shape and dtype. 
+ TORCH_CHECK(Y.sizes().vec() == out_sizes); + TORCH_CHECK(Y.dtype() == at::kBFloat16); + } else { + Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); + } + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, int32_t>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, int32_t>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, int32_t>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. 
+ cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, B)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, B)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, B)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, B}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. 
+ (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr()), + ElementBias(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // bias + // compute_1 + { + {reinterpret_cast(x_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<1>(), cute::Int<0>(), int32_t(M)}}, // x_scale + // compute_0 + { + {reinterpret_cast(w_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast(x_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<1>(), cute::Int<0>(), int32_t(M)}}, // x_scale + // compute_0 + { + {reinterpret_cast(w_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +// FP8 Rowwise batched Cutlass kernel dispatch. +template +at::Tensor dispatch_fp8_rowwise_batched_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + KernelMode kernel = get_batched_kernel_mode(XQ, WQ); + TORCH_CHECK( + (XQ.dim() == 3 && WQ.dim() == 3), + "FP8 rowwise batched GEMM only supports 3D inputs"); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_batched_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_batched_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return f8f8bf16_rowwise_batched_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } +} + +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + // Check datatypes. 
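+ // Editor-added note on accepted dtypes (grounded in the checks and dispatch below): XQ may be e4m3 or e5m2, WQ is always treated as e4m3, scales must be fp32, and bias, if provided, must be bf16 or fp32.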
+ TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. + bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } +} + +#else + +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h index 94a68096d..93b96fb04 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h @@ -31,4 +31,25 @@ inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { } } +inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto B = 
XQ.size(0); + auto M = XQ.size(1); + auto K = XQ.size(2); + auto N = WQ.size(1); + auto BM = B * M; + auto BN = B * N; + auto BK = B * K; + // Use a large kernel if at least two shapes are large. + bool use_large_kernel = + ((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BN >= 2048) || + (BK >= 2048 && BN >= 2048)); + if (BM <= 128 || BN <= 128) { + return KernelMode::Small; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 39084712c..ff5c66766 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -57,6 +57,14 @@ at::Tensor f8f8bf16_rowwise( std::optional<at::Tensor> bias = c10::nullopt, bool use_fast_accum = true, std::optional<at::Tensor> output = c10::nullopt); +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional<at::Tensor> bias = c10::nullopt, + bool use_fast_accum = true, + std::optional<at::Tensor> output = c10::nullopt); at::Tensor f8f8bf16_blockwise( at::Tensor XQ, at::Tensor WQ, @@ -132,6 +140,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor"); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); + m.def( + "f8f8bf16_rowwise_batched(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor"); m.def( "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); @@ -188,6 +198,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("i8i8bf16", i8i8bf16); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); + m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); #endif } @@ -216,6 +227,21 @@ at::Tensor f8f8bf16_rowwise_meta( return Y; } +at::Tensor f8f8bf16_rowwise_batched_meta( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor /* x_scale */, + at::Tensor /* w_scale */, + std::optional<at::Tensor> /* bias = c10::nullopt */, + bool /* use_fast_accum = true */, + std::optional<at::Tensor> /* output = c10::nullopt */) { + int B = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + auto Y = at::empty({B, M, N}, XQ.options().dtype(at::kBFloat16)); + return Y; +} + at::Tensor f8f8bf16_blockwise_meta( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -331,6 +357,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("i8i8bf16", i8i8bf16_meta); m.impl("f8f8bf16", f8f8bf16_meta); m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta); + m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta); #endif diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index fdd038e2c..c21c1713a 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -614,12 +614,16 @@ def test_quantize_fp8_per_tensor_with_ub( zq_ref = (x @ w.T).to(torch.bfloat16) torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) + @unittest.skipIf( + not torch.version.cuda, "Skip on AMD: BMM ops are not yet supported."
+ ) @settings(deadline=None) @given( B=st.sampled_from([1, 4]), M=st.sampled_from([2048, 4096]), N=st.sampled_from([128, 256]), K=st.sampled_from([256, 512]), + use_loopover=st.sampled_from([True, False]), ) def test_fp8_batched_gemm( self, @@ -627,6 +631,7 @@ def test_fp8_batched_gemm( M: int, N: int, K: int, + use_loopover: bool, ) -> None: x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1 w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01 @@ -655,7 +660,10 @@ def fp8_loopover_bmm( return y y_ref = torch.bmm(x, w.transpose(1, 2)) - y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + if use_loopover: + y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + else: + y_fp8 = torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale) torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2) @unittest.skipIf(torch.version.hip, "Skip on AMD: Marlin not yet suported.")
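
A minimal usage sketch of the new op (not part of the patch). It assumes an SM90 GPU with CUDA >= 12 and an fbgemm_gpu build that includes this kernel; the quantize_rowwise_e4m3 helper, the shapes, and the Python import path are illustrative assumptions, not fbgemm APIs -- the unit test above exercises the same call with fbgemm's own quantization utilities.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- registers the fbgemm ops; module path may differ by build

def quantize_rowwise_e4m3(t: torch.Tensor):
    # Hypothetical helper: one fp32 scale per row over the last dim; 448 is the e4m3 max value.
    scale = t.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / 448.0
    return (t / scale).to(torch.float8_e4m3fn), scale.squeeze(-1)

B, M, N, K = 4, 2048, 256, 512
x = torch.rand(B, M, K, dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.rand(B, N, K, dtype=torch.bfloat16, device="cuda") * 0.01
xq, x_scale = quantize_rowwise_e4m3(x)  # xq: B x M x K fp8, x_scale: B x M fp32
wq, w_scale = quantize_rowwise_e4m3(w)  # wq: B x N x K fp8, w_scale: B x N fp32

# y[b] ~= x[b] @ w[b].T, returned as bf16 with shape B x M x N.
y = torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale)
torch.testing.assert_close(torch.bmm(x, w.transpose(1, 2)), y, atol=8.0e-2, rtol=8.0e-2)

Internally the wrapper dispatches on FP8 flavor (e4m3 vs. e5m2), fast accumulation, and bias dtype, and get_batched_kernel_mode picks a Small/Large/Default CUTLASS tile configuration from B*M, B*N, and B*K.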