From d3eae1da1b5b9a9ea6e24fcf8e0b4d71a64710e8 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Fri, 4 Oct 2024 09:42:54 -0700
Subject: [PATCH] MoE BMM INT4 rowwise weight-only (#3219)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3219

X-link: https://github.com/facebookresearch/FBGEMM/pull/316

Marlin int4 weight-only with loopover for bmm performs well (**up to 7x faster** than bf16 bmm) when dim M is small to medium (e.g., < 256), as in decode. For larger dim M, as in prefill, we can instead use this bmm int4 rowwise weight-only kernel, which is around **1.5x faster** than the marlin int4 loopover while maintaining the same accuracy.

More results can be found in this [data sheet](https://docs.google.com/spreadsheets/d/12JWt3SqX_1GSLKwjGyt0KQl9SMWDF0r0C63MMKsE9JM/edit?usp=sharing).
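
A minimal usage sketch, mirroring the updated unit test below (`int4_row_quantize` and `pack_int4` are the existing helpers used in `quantize_test.py`; their import is omitted here, and the sizes are just example values):

```python
import torch

# x holds BF16 activations, w holds BF16 weights to be quantized per expert/batch.
B, M, N, K, group_size = 4, 2048, 256, 512, 128
x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(B, N, K, dtype=torch.bfloat16, device="cuda")

wq, w_scale, w_zp = [], [], []
for i in range(B):
    # int4_row_quantize / pack_int4 come from the existing test utilities.
    wq_, w_scale_, w_zp_ = int4_row_quantize(w[i], group_size)
    wq.append(pack_int4(wq_).contiguous().to(device="cuda"))
    w_scale.append(w_scale_.contiguous().to(device="cuda"))
    w_zp.append(w_zp_.contiguous().to(device="cuda"))

wq = torch.stack(wq)                        # B x N x K int4 weights (packed)
w_scale = torch.stack(w_scale).view(-1, N)  # (B * num_groups) x N
w_zp = torch.stack(w_zp).view(-1, N)        # (B * num_groups) x N

# BF16 activation x INT4 weight batched GEMM -> B x M x N BF16 output.
y = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, wq, w_scale, w_zp)
```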

Reviewed By: jianyuh

Differential Revision: D63818529

fbshipit-source-id: 127e841fa7c6c1ce810b6e8b6e35907eeaecafd6
---
 fbgemm_gpu/experimental/gen_ai/CMakeLists.txt |   1 +
 .../bf16i4bf16_rowwise_batched.cu             | 298 ++++++++++++++++++
 .../gen_ai/src/quantize/quantize.cpp          |  28 +-
 .../gen_ai/test/quantize/quantize_test.py     |  60 ++--
 4 files changed, 366 insertions(+), 21 deletions(-)
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu

diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 5accb9c53..dd6f165ed 100644
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -49,6 +49,7 @@ else()
     src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu
     src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu
     src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu
+    src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
     src/quantize/quantize.cu
     src/quantize/quantize.cpp)
 endif()
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
new file mode 100644
index 000000000..871543a2f
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
@@ -0,0 +1,298 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/torch.h>
+
+// clang-format off
+// The fixed ordering of the headers is required for CUTLASS 3.2+
+#include <cute/tensor.hpp>
+#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
+#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
+#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
+// clang-format on
+
+#include "cutlass_extensions/include/kernel_mode.h"
+
+namespace fbgemm_gpu {
+
+#if CUDART_VERSION >= 12000
+
+template <
+    int TB_M,
+    int TB_N,
+    int TB_K,
+    int TBS_M,
+    int TBS_N,
+    int TBS_K,
+    bool PONG,
+    typename WEIGHT_SCALE_DTYPE>
+at::Tensor bf16i4bf16_rowwise_batched_impl(
+    at::Tensor X, // BF16
+    at::Tensor WQ, // INT4
+    at::Tensor w_scale,
+    at::Tensor w_zp) {
+  // XQ: B x M x K
+  // WQ: B x N x K
+  // output: B x M x N
+  int B = X.size(0);
+  int M = X.size(1);
+  int N = WQ.size(1);
+  int K = X.size(2);
+
+  int num_groups = w_scale.size(0) / B;
+
+  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+  TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous());
+  TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous());
+  TORCH_CHECK(K >= num_groups && K % num_groups == 0);
+
+  int group_size = K / num_groups;
+
+  auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16));
+
+  using ElementInputA = cutlass::bfloat16_t;
+  using LayoutInputA = cutlass::layout::ColumnMajor;
+  constexpr int AlignmentInputA =
+      128 /
+      cutlass::sizeof_bits<
+          ElementInputA>::value; // Memory access granularity/alignment of A
+                                 // matrix in units of elements (up to 16 bytes)
+
+  using ElementInputB = cutlass::int4b_t;
+  using LayoutInputB = cutlass::layout::RowMajor;
+  constexpr int AlignmentInputB =
+      128 /
+      cutlass::sizeof_bits<
+          ElementInputB>::value; // Memory access granularity/alignment of B
+                                 // matrix in units of elements (up to 16 bytes)
+
+  using ElementScale = WEIGHT_SCALE_DTYPE;
+  using ElementZeroPoint = WEIGHT_SCALE_DTYPE;
+  using ElementComputeEpilogue = float;
+  using ElementAccumulator = float;
+
+  using ElementOutput = cutlass::bfloat16_t;
+  using LayoutOutput = cutlass::layout::ColumnMajor;
+  constexpr int AlignmentOutput =
+      128 /
+      cutlass::sizeof_bits<
+          ElementOutput>::value; // Memory access granularity/alignment of C
+                                 // matrix in units of elements (up to 16 bytes)
+
+  using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
+                                       // supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using TileShape = cute::Shape<
+      cute::Int<TB_M>,
+      cute::Int<TB_N>,
+      cute::Int<TB_K>>; // Threadblock-level
+                        // tile size
+  using ClusterShape = cute::Shape<
+      cute::Int<TBS_M>,
+      cute::Int<TBS_N>,
+      cute::Int<TBS_K>>; // Shape of the
+                         // threadblocks in a
+                         // cluster
+  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using PongSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+  using MainLoopSchedule =
+      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm90,
+          cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          EpilogueTileType,
+          ElementAccumulator,
+          ElementAccumulator,
+          ElementOutput,
+          LayoutOutput,
+          AlignmentOutput,
+          ElementOutput,
+          LayoutOutput,
+          AlignmentOutput,
+          EpilogueSchedule>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          cute::tuple<ElementInputB, ElementScale, ElementZeroPoint>,
+          LayoutInputB,
+          AlignmentInputB,
+          ElementInputA,
+          LayoutInputA,
+          AlignmentInputA,
+          ElementAccumulator,
+          TileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainLoopSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape<int, int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideInputA = typename Gemm::GemmKernel::StrideA;
+  using StrideInputB = typename Gemm::GemmKernel::StrideB;
+  using StrideOutput = typename Gemm::GemmKernel::StrideC;
+  using StrideS = typename CollectiveMainloop::StrideScale;
+
+  StrideInputA stride_a = cutlass::make_cute_packed_stride(
+      StrideInputA{}, cute::make_shape(M, K, B));
+  StrideInputB stride_b = cutlass::make_cute_packed_stride(
+      StrideInputB{}, cute::make_shape(N, K, B));
+  StrideOutput stride_output = cutlass::make_cute_packed_stride(
+      StrideOutput{}, cute::make_shape(N, M, B));
+  StrideS stride_S = cutlass::make_cute_packed_stride(
+      StrideS{}, cute::make_shape(N, num_groups, B));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {N, M, K, B},
+      {reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
+       stride_b,
+       reinterpret_cast<ElementInputA*>(X.data_ptr()),
+       stride_a,
+       reinterpret_cast<ElementScale*>(w_scale.data_ptr()),
+       stride_S,
+       group_size,
+       reinterpret_cast<ElementZeroPoint*>(w_zp.data_ptr())},
+      {{1.0, 0.0},
+       (ElementOutput*)Y.data_ptr(),
+       stride_output,
+       (ElementOutput*)Y.data_ptr(),
+       stride_output}};
+
+  Gemm gemm;
+
+  // Using the arguments, query for extra workspace required for matrix
+  // multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  // Check the problem size is supported or not
+  cutlass::Status status = gemm.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot implement");
+  }
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  status = gemm.initialize(arguments, workspace.get());
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot initialize");
+  }
+
+  status = gemm(at::cuda::getCurrentCUDAStream());
+
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error(
+        std::string("cutlass cannot run") +
+        cutlass::cutlassGetStatusString(status));
+  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return Y;
+}
+
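+// Note on the configurations below: KernelMode::Small selects a 64x128x64
+// threadblock tile with the pingpong schedule, KernelMode::Large selects
+// 128x128x64 with pingpong, and the default case selects 128x128x64 with the
+// non-pingpong TMA warp-specialized schedule; all use a 2x1x1 cluster.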
+template <typename WEIGHT_SCALE_DTYPE>
+at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
+    at::Tensor X, // BF16
+    at::Tensor WQ, // INT4
+    at::Tensor w_scale,
+    at::Tensor w_zp) {
+  KernelMode kernel = get_batched_kernel_mode(X, WQ);
+  if (kernel == KernelMode::Small) {
+    return bf16i4bf16_rowwise_batched_impl<
+        64,
+        128,
+        64,
+        2,
+        1,
+        1,
+        true,
+        WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
+  } else if (kernel == KernelMode::Large) {
+    return bf16i4bf16_rowwise_batched_impl<
+        128,
+        128,
+        64,
+        2,
+        1,
+        1,
+        true,
+        WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
+  } else {
+    return bf16i4bf16_rowwise_batched_impl<
+        128,
+        128,
+        64,
+        2,
+        1,
+        1,
+        false,
+        WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
+  }
+}
+
+at::Tensor bf16i4bf16_rowwise_batched(
+    at::Tensor X, // BF16
+    at::Tensor WQ, // INT4
+    at::Tensor w_scale,
+    at::Tensor w_zp) {
+  // Check datatypes.
+  TORCH_CHECK(
+      (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) ||
+          (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) ||
+          (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
+      "Weight scale and zero point tensors must be float32, bfloat16, or float16, and the weight scale and zero point tensors must have the same dtype.");
+
+  if (w_scale.dtype() == at::kFloat) {
+    return dispatch_bf16i4bf16_rowwise_batched_kernel<float>(
+        X, WQ, w_scale, w_zp);
+  } else if (w_scale.dtype() == at::kHalf) {
+    return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::half_t>(
+        X, WQ, w_scale, w_zp);
+  } else if (w_scale.dtype() == at::kBFloat16) {
+    return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::bfloat16_t>(
+        X, WQ, w_scale, w_zp);
+  } else {
+    throw std::runtime_error(
+        "Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched");
+  }
+}
+
+#else
+
+at::Tensor bf16i4bf16_rowwise_batched(
+    at::Tensor X, // BF16
+    at::Tensor WQ, // INT4
+    at::Tensor w_scale,
+    at::Tensor w_zp) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
+#endif
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 101a5cba1..1abf8fb40 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -96,6 +96,11 @@ at::Tensor bf16i4bf16_rowwise(
     at::Tensor WQ,
     at::Tensor w_scale,
     at::Tensor w_zp);
+at::Tensor bf16i4bf16_rowwise_batched(
+    at::Tensor X,
+    at::Tensor WQ,
+    at::Tensor w_scale,
+    at::Tensor w_zp);
 at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale);
 std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);
 
@@ -152,6 +157,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
 
+  m.def(
+      "bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
+  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
+
   m.def(
       "i8i8bf16_dynamic(Tensor XQ, Tensor WQ, Tensor scale, int split_k=1) -> Tensor");
   m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic);
@@ -326,14 +335,28 @@ at::Tensor f8i4bf16_rowwise_meta(
 at::Tensor bf16i4bf16_rowwise_meta(
     at::Tensor X, // BF16
     at::Tensor WQ, // INT4
-    at::Tensor w_scale,
-    at::Tensor w_zp) {
+    at::Tensor /* w_scale */,
+    at::Tensor /* w_zp */
+) {
   int M = X.size(0);
   int N = WQ.size(0);
   auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16));
   return Y;
 }
 
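+// Meta ("fake") implementation: computes only the output shape and dtype so
+// the op can be traced (e.g., by torch.compile / FakeTensor); it is registered
+// under the Meta dispatch key below.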
+at::Tensor bf16i4bf16_rowwise_batched_meta(
+    at::Tensor X, // BF16
+    at::Tensor WQ, // INT4
+    at::Tensor /* w_scale */,
+    at::Tensor /* w_zp */
+) {
+  int B = X.size(0);
+  int M = X.size(1);
+  int N = WQ.size(1);
+  auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16));
+  return Y;
+}
+
 std::vector<at::Tensor> quantize_fp8_per_row_meta(
     at::Tensor input,
     std::optional<at::Tensor> bs,
@@ -370,6 +393,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
+  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
 #endif
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index c21c1713a..38a09f360 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -673,6 +673,7 @@ def fp8_loopover_bmm(
         M=st.sampled_from([2048, 4096]),
         N=st.sampled_from([256, 512]),
         K=st.sampled_from([256, 512]),
+        use_loopover=st.sampled_from([True, False]),
     )
     def test_int4_batched_gemm(
         self,
@@ -680,6 +681,7 @@ def test_int4_batched_gemm(
         M: int,
         N: int,
         K: int,
+        use_loopover: bool,
     ) -> None:
         if not MARLIN_ENABLED:
             return
@@ -689,28 +691,48 @@ def test_int4_batched_gemm(
         wq = []
         w_scale = []
         group_size = 128
-        for i in range(B):
-            _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), group_size)
-            wq.append(wq_)
-            w_scale.append(w_scale_)
-        wq = torch.stack(wq)
-        w_scale = torch.stack(w_scale)
-
-        def int4_loopover_bmm(
-            x: torch.Tensor,
-            wq: torch.Tensor,
-            w_scale: torch.Tensor,
-        ) -> torch.Tensor:
-            B = x.shape[0]
-            M = x.shape[1]
-            N = w_scale.shape[2]
-            y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device)
+
+        if use_loopover:
             for i in range(B):
-                y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i])
-            return y
+                _, wq_, w_scale_ = marlin_quantize(
+                    w[i].cuda().t().contiguous(), group_size
+                )
+                wq.append(wq_)
+                w_scale.append(w_scale_)
+            wq = torch.stack(wq)
+            w_scale = torch.stack(w_scale)
+
+            def int4_loopover_bmm(
+                x: torch.Tensor,
+                wq: torch.Tensor,
+                w_scale: torch.Tensor,
+            ) -> torch.Tensor:
+                B = x.shape[0]
+                M = x.shape[1]
+                N = w_scale.shape[2]
+                y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device)
+                for i in range(B):
+                    y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i])
+                return y
+
+            y_int4 = int4_loopover_bmm(x, wq, w_scale)
+        else:
+            w_zp = []
+            for i in range(B):
+                wq_, w_scale_, w_zp_ = int4_row_quantize(w[i], group_size)
+
+                wq_ = pack_int4(wq_).contiguous().to(device="cuda")
+                w_scale_ = w_scale_.contiguous().to(device="cuda")
+                w_zp_ = w_zp_.contiguous().to(device="cuda")
+                wq.append(wq_)
+                w_scale.append(w_scale_)
+                w_zp.append(w_zp_)
+            wq = torch.stack(wq)
+            w_scale = torch.stack(w_scale).view(-1, N)
+            w_zp = torch.stack(w_zp).view(-1, N)
+            y_int4 = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, wq, w_scale, w_zp)
 
         y_ref = torch.bmm(x, w.transpose(1, 2))
-        y_int4 = int4_loopover_bmm(x, wq, w_scale)
         torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2)
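
For context on the decode/prefill split described in the summary, a hypothetical dispatch sketch (not taken from this diff; the 256 threshold, the argument names, and keeping both weight formats precomputed are illustrative assumptions):

```python
import torch

def int4_bmm(x, marlin_wq, marlin_scale, rw_wq, rw_scale, rw_zp, m_threshold=256):
    """Illustrative choice between the two int4 weight-only BMM paths.

    Assumes both weight formats were prepared ahead of time: marlin_wq/marlin_scale
    via marlin_quantize, and rw_wq/rw_scale/rw_zp via int4_row_quantize + pack_int4.
    """
    B, M, _ = x.shape
    if M < m_threshold:
        # Decode-sized M: loop over the batch with the marlin int4 GEMM.
        N = marlin_scale.shape[2]
        y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x.device)
        for i in range(B):
            y[i] = torch.ops.marlin.marlin_gemm(x[i], marlin_wq[i], marlin_scale[i])
        return y
    # Prefill-sized M: single batched int4 rowwise weight-only GEMM.
    return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, rw_wq, rw_scale, rw_zp)
```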