MoE BMM INT4 rowwise weight-only (#3219)
Summary:
Pull Request resolved: #3219

X-link: facebookresearch/FBGEMM#316

Marlin int4 weight-only with loopover for bmm performs well (**up to 7x faster** than bf16 bmm) when dim M is small to medium (e.g., < 256), as in decode. For larger dim M, as in prefill, we can leverage this bmm int4 rowwise weight-only kernel, which is around **1.5x faster** than the Marlin int4 loopover while maintaining the same accuracy.

More results can be found in this [data sheet](https://docs.google.com/spreadsheets/d/12JWt3SqX_1GSLKwjGyt0KQl9SMWDF0r0C63MMKsE9JM/edit?usp=sharing)
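
For illustration, a hedged sketch of how the two paths could be combined (not part of this diff; `marlin_bmm_int4_loopover` and the 256 cutoff are placeholders, only `bf16i4bf16_rowwise_batched` is added here):

```cpp
#include <ATen/ATen.h>

// Forward declarations: bf16i4bf16_rowwise_batched is the op added in this
// diff; marlin_bmm_int4_loopover is a hypothetical stand-in for the existing
// Marlin int4 loopover BMM path.
namespace fbgemm_gpu {
at::Tensor bf16i4bf16_rowwise_batched(
    at::Tensor X, at::Tensor WQ, at::Tensor w_scale, at::Tensor w_zp);
}
at::Tensor marlin_bmm_int4_loopover(
    at::Tensor X, at::Tensor WQ, at::Tensor w_scale, at::Tensor w_zp);

// Sketch of decode/prefill routing between the two int4 weight-only BMM paths.
// The 256 cutoff is illustrative, taken from the summary above.
at::Tensor moe_int4_bmm(
    at::Tensor X,        // [B, M, K] bf16 activations
    at::Tensor WQ,       // [B, N, K] int4 rowwise-quantized weights
    at::Tensor w_scale,
    at::Tensor w_zp) {
  const int64_t M = X.size(1);
  if (M < 256) {
    // Small/medium M (decode): Marlin loopover is up to ~7x faster than bf16 bmm.
    return marlin_bmm_int4_loopover(X, WQ, w_scale, w_zp);
  }
  // Larger M (prefill): the batched CUTLASS kernel is ~1.5x faster than the loopover.
  return fbgemm_gpu::bf16i4bf16_rowwise_batched(X, WQ, w_scale, w_zp);
}
```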

Reviewed By: jianyuh

Differential Revision: D63818529

fbshipit-source-id: 127e841fa7c6c1ce810b6e8b6e35907eeaecafd6
jiawenliu64 authored and facebook-github-bot committed Oct 4, 2024
1 parent a0966e8 commit d3eae1d
Showing 4 changed files with 366 additions and 21 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -49,6 +49,7 @@ else()
src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu
src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu
src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu
src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
src/quantize/quantize.cu
src/quantize/quantize.cpp)
endif()
298 changes: 298 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
@@ -0,0 +1,298 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

// clang-format off
// The fixed ordering of the headers is required for CUTLASS 3.2+
#include <cute/tensor.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
// clang-format on

#include "cutlass_extensions/include/kernel_mode.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
typename WEIGHT_SCALE_DTYPE>
at::Tensor bf16i4bf16_rowwise_batched_impl(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
// X: B x M x K
// WQ: B x N x K
// output: B x M x N
int B = X.size(0);
int M = X.size(1);
int N = WQ.size(1);
int K = X.size(2);

int num_groups = w_scale.size(0) / B;

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous());
TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous());
TORCH_CHECK(K >= num_groups && K % num_groups == 0);

int group_size = K / num_groups;

auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::bfloat16_t;
using LayoutInputA = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::int4b_t;
using LayoutInputB = cutlass::layout::RowMajor;
constexpr int AlignmentInputB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementScale = WEIGHT_SCALE_DTYPE;
using ElementZeroPoint = WEIGHT_SCALE_DTYPE;
using ElementComputeEpilogue = float;
using ElementAccumulator = float;

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
ElementOutput>::value; // Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)

using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementAccumulator,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
EpilogueSchedule>::CollectiveOp;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
cute::tuple<ElementInputB, ElementScale, ElementZeroPoint>,
LayoutInputB,
AlignmentInputB,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainLoopSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideS = typename CollectiveMainloop::StrideScale;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, B));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, B));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, B));
StrideS stride_S = cutlass::make_cute_packed_stride(
StrideS{}, cute::make_shape(N, num_groups, B));
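// Note: the mixed-dtype mainloop takes the quantized operand first, so the
// problem is posed as (N, M, K, B): WQ (with its scales and zero points) is
// passed as the first operand and X as the second. The column-major N x M
// output per batch has the same memory layout as the row-major B x M x N
// tensor Y.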

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, B},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(X.data_ptr()),
stride_a,
reinterpret_cast<ElementScale*>(w_scale.data_ptr()),
stride_S,
group_size,
reinterpret_cast<ElementZeroPoint*>(w_zp.data_ptr())},
{{1.0, 0.0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output}};

Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());

if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run: ") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

template <typename WEIGHT_SCALE_DTYPE>
at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
KernelMode kernel = get_batched_kernel_mode(X, WQ);
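// Choose the threadblock tile shape (TB_M, TB_N, TB_K), cluster shape
// (TBS_M, TBS_N, TBS_K), and whether to use the pingpong schedule based on
// the batched problem size.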
if (kernel == KernelMode::Small) {
return bf16i4bf16_rowwise_batched_impl<
64,
128,
64,
2,
1,
1,
true,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
} else if (kernel == KernelMode::Large) {
return bf16i4bf16_rowwise_batched_impl<
128,
128,
64,
2,
1,
1,
true,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
} else {
return bf16i4bf16_rowwise_batched_impl<
128,
128,
64,
2,
1,
1,
false,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
}
}

at::Tensor bf16i4bf16_rowwise_batched(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
// Check datatypes.
TORCH_CHECK(
(w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) ||
(w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) ||
(w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
"Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same .");

if (w_scale.dtype() == at::kFloat) {
return dispatch_bf16i4bf16_rowwise_batched_kernel<float>(
X, WQ, w_scale, w_zp);
} else if (w_scale.dtype() == at::kHalf) {
return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::half_t>(
X, WQ, w_scale, w_zp);
} else if (w_scale.dtype() == at::kBFloat16) {
return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::bfloat16_t>(
X, WQ, w_scale, w_zp);
} else {
throw std::runtime_error(
"Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched");
}
}

#else

at::Tensor bf16i4bf16_rowwise_batched(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
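
The kernel above enforces its input contract via TORCH_CHECKs; a minimal call sketch follows. The shape comments mirror those checks, while the [B * num_groups, N] scale/zero-point layout is inferred from num_groups = w_scale.size(0) / B and should be read as an assumption, not documentation.

```cpp
#include <ATen/ATen.h>

// Declaration of the op added in this diff (defined in the file above).
namespace fbgemm_gpu {
at::Tensor bf16i4bf16_rowwise_batched(
    at::Tensor X, at::Tensor WQ, at::Tensor w_scale, at::Tensor w_zp);
}

// Sketch only: the int4 packing of WQ and the [B * num_groups, N]
// scale/zero-point layout are assumptions produced by upstream quantization
// helpers that are not part of this file.
at::Tensor call_bf16i4bf16_rowwise_batched(
    at::Tensor X,        // [B, M, K] bf16, CUDA, contiguous
    at::Tensor WQ,       // [B, N, K] int4 rowwise-quantized weights, CUDA, contiguous
    at::Tensor w_scale,  // [B * num_groups, N] group-wise scales (fp32/fp16/bf16)
    at::Tensor w_zp) {   // [B * num_groups, N] group-wise zero points, same dtype as w_scale
  // group_size = K / num_groups; K must be a positive multiple of num_groups.
  return fbgemm_gpu::bf16i4bf16_rowwise_batched(X, WQ, w_scale, w_zp);  // -> [B, M, N] bf16
}
```
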
28 changes: 26 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -96,6 +96,11 @@ at::Tensor bf16i4bf16_rowwise(
at::Tensor WQ,
at::Tensor w_scale,
at::Tensor w_zp);
at::Tensor bf16i4bf16_rowwise_batched(
at::Tensor X,
at::Tensor WQ,
at::Tensor w_scale,
at::Tensor w_zp);

at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale);
std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);
@@ -152,6 +157,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);

m.def(
"bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);

m.def(
"i8i8bf16_dynamic(Tensor XQ, Tensor WQ, Tensor scale, int split_k=1) -> Tensor");
m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic);
@@ -326,14 +335,28 @@ at::Tensor f8i4bf16_rowwise_meta(
at::Tensor bf16i4bf16_rowwise_meta(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
at::Tensor /* w_scale */,
at::Tensor /* w_zp */
) {
int M = X.size(0);
int N = WQ.size(0);
auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor bf16i4bf16_rowwise_batched_meta(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor /* w_scale */,
at::Tensor /* w_zp */
) {
int B = X.size(0);
int M = X.size(1);
int N = WQ.size(1);
auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

std::vector<at::Tensor> quantize_fp8_per_row_meta(
at::Tensor input,
std::optional<at::Tensor> bs,
@@ -370,6 +393,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
#endif
}
